# Masked batchnorm in PyTorch

If you just want to see the code, it’s in a gist.

## Introduction: Sequence tokenization and padding

Sometimes the fact that all proteins are not the same length is the bane of my existence.

One way we deal with variable-length sequences in machine learning is to pad them all to the same length. For example, if I have the following nucleotide sequences:

```
AGCTAG
AGCTAGA
AGCTA
```

I can right-pad them by introducing a special padding token, let’s call it `p`:

```
AGCTAGp
AGCTAGA
AGCTApp
```

And then I can use this as a batch input to the machine-learning model of my choice by assigning each nucleotide and the padding token an integer:

```
x = torch.tensor([
    [0, 2, 1, 3, 0, 2, 4],
    [0, 2, 1, 3, 0, 2, 0],
    [0, 2, 1, 3, 0, 4, 4],
])
```
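Concretely, the tokenization above can be sketched like this. The vocabulary dict here is my own illustration; the integer assignments are arbitrary as long as they're consistent, and these match the tensor shown:

```python
import torch

# Hypothetical token-to-integer vocabulary (matches the tensor above)
vocab = {"A": 0, "C": 1, "G": 2, "T": 3, "p": 4}

seqs = ["AGCTAG", "AGCTAGA", "AGCTA"]
max_len = max(len(s) for s in seqs)

# Right-pad each sequence with the padding token, then map to integers
padded = [s + "p" * (max_len - len(s)) for s in seqs]
x = torch.tensor([[vocab[c] for c in s] for s in padded])
# x has shape (3, 7)
```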

If I want to put sequences through a transformer, RNN, or CNN and I don’t want the model to be affected by the padding, I can also pass in a mask (`mask = (x != 4)`) and use that mask to set things to zero where appropriate. For example, the PyTorch `Transformer` class uses this sort of mask (but as a `ByteTensor`) for its `src_key_padding_mask`/`tgt_key_padding_mask`/`memory_key_padding_mask` arguments.
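A short sketch of both conventions, using the tensor from above. Note the embedding tensor here is a made-up example, and that `Transformer`'s padding masks use the *opposite* polarity (`True` marks positions to ignore):

```python
import torch

x = torch.tensor([
    [0, 2, 1, 3, 0, 2, 4],
    [0, 2, 1, 3, 0, 2, 0],
    [0, 2, 1, 3, 0, 4, 4],
])

mask = (x != 4)  # True at real tokens, False at padding

# Hypothetical (B, L, D) embeddings; zero out the padded time steps
emb = torch.randn(3, 7, 8)
emb = emb * mask.unsqueeze(-1)

# For nn.Transformer's *_key_padding_mask arguments, the convention
# is inverted: True where a position should be ignored
key_padding_mask = (x == 4)
```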

## Trying to extend PyTorch’s batchnorm

Unfortunately, `nn.BatchNorm1d` doesn’t support this type of masking, so if I zero out padding locations, then my minibatch statistics get artificially lowered by the extra zeros. Given PyTorch’s object-oriented nature, the most elegant way to implement masked batchnorm would be to extend one of its classes and modify the way minibatch statistics are calculated.

Starting at `nn.BatchNorm1d`, we find that all this class implements is a method for checking input dimensions:

```
def _check_input_dim(self, input):
    if input.dim() != 2 and input.dim() != 3:
        raise ValueError('expected 2D or 3D input (got {}D input)'
                         .format(input.dim()))
```

Its superclass (`nn._BatchNorm`) has a `forward` method, which checks whether to use `train` or `eval` mode, retrieves the parameters needed to calculate the moving averages, and then calls `F.batch_norm`. `F.batch_norm` in turn calls `torch.batch_norm`. Clicking on that in GitHub leads back to `F.batch_norm`: I think it must actually be implemented in the lower-level C++ code.

In any case, it looks like there’s no straightforward way to extend PyTorch’s batchnorm implementation, so it’s time to write it from scratch.

## MaskedBatchNorm1d

Given a `(B, 1, L)` mask, we first mask the input and then compute the number of unmasked locations over which to calculate the minibatch statistics:

```
if input_mask is not None:
    masked = input * input_mask
    n = input_mask.sum()
```

Then calculate the minibatch mean:

```
masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
current_mean = masked_sum / n
```

And the minibatch variance, re-applying the mask so that padded positions don’t contribute spurious `(0 - mean)²` terms:

```
current_var = (((masked - current_mean) * input_mask) ** 2)
current_var = current_var.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
```
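Putting the pieces together, a minimal version of the module might look like the sketch below. This is not the gist's exact code: the buffer names, momentum handling, and affine parameters follow standard batchnorm conventions and are my assumptions.

```python
import torch
import torch.nn as nn


class MaskedBatchNorm1d(nn.Module):
    """BatchNorm1d over (B, C, L) inputs that ignores masked positions.

    A sketch of the approach described above, not the gist's exact code.
    input_mask has shape (B, 1, L), with 1 at valid positions, 0 at padding.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))
        self.register_buffer("running_mean", torch.zeros(1, num_features, 1))
        self.register_buffer("running_var", torch.ones(1, num_features, 1))

    def forward(self, input, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input[:, :1, :])
        masked = input * input_mask
        n = input_mask.sum()  # number of unmasked positions
        if self.training:
            # Per-channel minibatch statistics over batch and length dims
            masked_sum = masked.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True)
            current_mean = masked_sum / n
            current_var = (((masked - current_mean) * input_mask) ** 2)
            current_var = current_var.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * current_mean
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * current_var
            mean, var = current_mean, current_var
        else:
            mean, var = self.running_mean, self.running_var
        normed = (masked - mean) / torch.sqrt(var + self.eps)
        # Re-mask so padded positions come out exactly zero
        return (normed * self.weight + self.bias) * input_mask
```

In use, it behaves like `nn.BatchNorm1d` but takes the optional mask as a second argument to `forward`.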

The full module is available as a gist.

## Limitations

Because I didn’t want to dig deeper into the PyTorch source, there are a few limitations here.

- It’s almost certainly not as fast as the native PyTorch implementation.
- If you’re doing multi-GPU training, minibatch statistics won’t be synced across devices as they would be with Apex’s SyncBatchNorm.
- If you’re doing mixed-precision training with Apex, you can’t use optimization level `O2`, because it won’t detect that this is a batchnorm layer and keep it in float precision.