# layers.py
import torch
from torch import nn
import torch.nn.functional as F

class FixedBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are fixed."""

    def __init__(self, n):
        super().__init__()
        # Buffers (not parameters): saved in the state dict but never trained.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        # training defaults to False: the stored running statistics are
        # always used and never updated.
        return F.batch_norm(x, running_mean=self.running_mean, running_var=self.running_var,
                            weight=self.weight, bias=self.bias)

def convert_fixedbn_model(module):
    """Recursively replace every nn.BatchNorm2d in a model with FixedBatchNorm2d."""
    mod = module
    if isinstance(module, nn.BatchNorm2d):
        mod = FixedBatchNorm2d(module.num_features)
        # Copy the learned statistics and affine parameters into the buffers,
        # cloning so the frozen copy shares no storage with the original module.
        mod.running_mean = module.running_mean.detach().clone()
        mod.running_var = module.running_var.detach().clone()
        if module.affine:
            mod.weight = module.weight.detach().clone()
            mod.bias = module.bias.detach().clone()
    for name, child in module.named_children():
        mod.add_module(name, convert_fixedbn_model(child))
    return mod
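

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal check, assuming only this file plus stock PyTorch: converting an
# eval-mode BatchNorm2d should leave its outputs unchanged, and nested
# modules should be converted recursively.
if __name__ == "__main__":
    bn = nn.BatchNorm2d(8)
    bn(torch.randn(32, 8, 4, 4))  # one training-mode pass to populate running stats
    bn.eval()

    fixed = convert_fixedbn_model(bn)
    x = torch.randn(4, 8, 16, 16)
    assert torch.allclose(bn(x), fixed(x), atol=1e-6)

    # The BatchNorm2d child below comes back as a FixedBatchNorm2d.
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    print(convert_fixedbn_model(net))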