Masked batchnorm in PyTorch
If you just want to see the code, it’s in a gist.
Introduction: Sequence tokenization and padding
Sometimes the fact that all proteins are not the same length is the bane of my existence.
One way we deal with variable-length sequences in machine learning is to pad them all to the same length. For example, if I have the following nucleotide sequences:
AGCTAG
AGCTAGA
AGCTA
I can right-pad them by introducing a special padding token, let’s call it p.
AGCTAGp
AGCTAGA
AGCTApp
And then I can use this as a batch input to the machine-learning model of my choice by assigning each nucleotide and the padding token an integer:
x = torch.tensor(
    [
        [0, 2, 1, 3, 0, 2, 4],
        [0, 2, 1, 3, 0, 2, 0],
        [0, 2, 1, 3, 0, 4, 4]
    ]
)
If I want to put sequences through a transformer or RNN or CNN and I don’t want the model to be affected by the padding, I can also pass in a mask (mask = (x != 4)) and use that mask to set things to zero where appropriate. For example, the PyTorch Transformer class uses this sort of mask (but with a ByteTensor) for its [src/tgt/memory]_key_padding_mask arguments.
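As a small illustration (the padding index 4 and the (B, C, L) feature shape below are just assumptions for the example), the mask can be built from the token tensor and broadcast against per-position features to zero out the padded positions:

import torch

PAD = 4
x = torch.tensor(
    [
        [0, 2, 1, 3, 0, 2, PAD],
        [0, 2, 1, 3, 0, 2, 0],
        [0, 2, 1, 3, 0, PAD, PAD]
    ]
)
mask = (x != PAD)                                # (B, L), True at real tokens

# Suppose the model produces per-position features of shape (B, C, L).
features = torch.randn(3, 8, 7)
masked_features = features * mask.unsqueeze(1)   # broadcast the mask over channels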
Trying to extend PyTorch’s batchnorm
Unfortunately, nn.BatchNorm1d doesn’t support this type of masking, so if I zero out padding locations, then my minibatch statistics get artificially lowered by the extra zeros. Given PyTorch’s object-oriented nature, the most elegant way to implement masked batchnorm would be to extend one of their classes and modify the way minibatch statistics are calculated.
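To make the problem concrete, here is a small illustrative comparison with made-up numbers: the zeros at padded positions drag the per-channel minibatch mean toward zero relative to the mean over only the real positions.

import torch

torch.manual_seed(0)
features = torch.randn(3, 8, 7) + 1.0           # (B, C, L), each channel has mean ~1
mask = torch.tensor([[1] * 6 + [0],
                     [1] * 7,
                     [1] * 5 + [0, 0]]).bool()  # (B, L), matching the lengths above

zeroed = features * mask.unsqueeze(1)           # padded positions set to zero

# What nn.BatchNorm1d would average over (all B * L positions) ...
mean_with_pads = zeroed.mean(dim=(0, 2))
# ... versus the mean over only the 18 unmasked positions.
mean_masked = zeroed.sum(dim=(0, 2)) / mask.sum()
print(mean_with_pads)   # pulled toward zero
print(mean_masked)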
Starting at nn.BatchNorm1d, we find that all this class implements is a method for checking input dimensions:
def _check_input_dim(self, input):
    if input.dim() != 2 and input.dim() != 3:
        raise ValueError('expected 2D or 3D input (got {}D input)'
                         .format(input.dim()))
Its superclass (nn._BatchNorm) has a forward method, which checks whether to use train or eval mode, retrieves the parameters needed to calculate the moving averages, and then calls F.batch_norm. F.batch_norm in turn calls torch.batch_norm. Clicking on that in GitHub leads back to F.batch_norm: I think it must actually be implemented in the lower-level C++ code.
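Roughly, and glossing over details that vary between versions (momentum handling, num_batches_tracked, and so on), that forward method boils down to something like this sketch:

import torch.nn.functional as F

# Simplified sketch of nn._BatchNorm.forward, not the actual source.
def forward(self, input):
    self._check_input_dim(input)
    # Use minibatch statistics in training mode (or when running stats are off),
    # otherwise normalize with the stored running statistics.
    use_batch_stats = self.training or not self.track_running_stats
    return F.batch_norm(
        input,
        self.running_mean, self.running_var,
        self.weight, self.bias,
        use_batch_stats, self.momentum, self.eps,
    )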
In any case, it looks like there’s no straightforward way to extend PyTorch’s batchnorm implementation, so it’s time to write it from scratch.
MaskedBatchNorm1d
Given a (B, 1, L) mask, we first mask the input and then compute the number of unmasked locations over which to calculate the minibatch statistics:
if input_mask is not None:
    masked = input * input_mask
    n = input_mask.sum()
Then calculate the minibatch mean:
masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
current_mean = masked_sum / n
And the minibatch variance:
current_var = (masked - current_mean) ** 2
if input_mask is not None:
    current_var = current_var * input_mask  # padded zeros would otherwise contribute (0 - mean)² terms
current_var = current_var.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
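As a self-contained (and entirely made-up) sanity check, these masked statistics should match the plain mean and variance computed over only the unmasked elements of each channel:

import torch

B, C, L = 3, 8, 7
input = torch.randn(B, C, L)
lengths = torch.tensor([6, 7, 5])
input_mask = (torch.arange(L) < lengths.unsqueeze(1)).unsqueeze(1).float()  # (B, 1, L)

masked = input * input_mask
n = input_mask.sum()
current_mean = masked.sum(dim=(0, 2), keepdim=True) / n
current_var = (((masked - current_mean) ** 2) * input_mask).sum(dim=(0, 2), keepdim=True) / n

keep = input_mask[:, 0, :].bool()           # (B, L)
for c in range(C):
    vals = input[:, c, :][keep]             # unmasked values in channel c
    assert torch.allclose(vals.mean(), current_mean[0, c, 0], atol=1e-6)
    assert torch.allclose(vals.var(unbiased=False), current_var[0, c, 0], atol=1e-6)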
The full module is available as a gist.
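If you’d rather not click through, here is a minimal, simplified sketch of how the pieces above might be assembled into a module. This is not the gist’s implementation: the eps, momentum, affine parameters, and running-statistics handling below are my own assumptions about what a reasonable version could look like.

import torch
import torch.nn as nn


class MaskedBatchNorm1d(nn.Module):
    """Sketch: batchnorm over a (B, C, L) input with an optional (B, 1, L) mask."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1))

    def forward(self, input, input_mask=None):
        # Count only the unmasked positions when averaging.
        if input_mask is not None:
            masked = input * input_mask
            n = input_mask.sum()
        else:
            masked = input
            n = input.size(0) * input.size(2)
        if self.training:
            mean = masked.sum(dim=(0, 2), keepdim=True) / n
            sq_dev = (masked - mean) ** 2
            if input_mask is not None:
                sq_dev = sq_dev * input_mask  # padded positions must not contribute
            var = sq_dev.sum(dim=(0, 2), keepdim=True) / n
            # Exponential moving average of the minibatch statistics.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean, var = self.running_mean, self.running_var
        return (input - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias

In training mode this normalizes with the masked minibatch statistics; in eval mode it falls back to the running averages, mirroring the behavior of nn.BatchNorm1d.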
Limitations
Because I didn’t want to dig deeper into the PyTorch source, there are a few limitations here.
- It’s almost certainly not as fast as the native PyTorch implementation.
- If you’re doing multi-GPU training, minibatch statistics won’t be synced across devices as they would be with Apex’s SyncBatchNorm.
- If you’re doing mixed-precision training with Apex, you can’t use opt level O2, because it won’t detect that this is a batchnorm layer and keep it in full (float32) precision.