Batch Normalization

Table of Contents

1. Batch Normalization

1.1. Overview

我们调节 \(w\) 的目的是为了保证输出在一定的范围内, 例如符合标准正态分布.

batch normalization 不调整 \(w\): 它对输出作 normalization, 保证输出是标准正态分布, batch normalization 与 feature scaling 时的 normalization 类似, `the mean and standard-deviation are calculated per-dimension over the mini-batches`

batch normalization 的代码基本上为:

mu = np.mean(X, axis=-1, keepdims=True)
std = np.std(X, axis=-1, keepdims=True)
X_normed = (X - mu) / (std + self.epsilon)
out = self.gamma * X_normed + self.beta

其中最后一步称为 scale and shift, self.gamma 和 self.beta 是可学习的参数, 为了调整 normalization 的结果.

以 relu 为例, 若直接使用 X_normed, 会导致有一半的 x 被丢弃 (X_normed 有 50% 的概率 < 0). 如果通过 scale and shift 能把 X_normed 向正方向移动一下, 就没有问题了.

import matplotlib.pyplot as plt
import torch

N = 100
plt.style.use("default")

# plt.ylim(0, 40)


def init_weight(in_features, out_features):
    return torch.nn.Parameter(torch.randn(in_features, out_features))


class Layer(torch.nn.Module):
    def __init__(self, in_features, out_features, n):
        super().__init__()
        self.n = n
        self.w = init_weight(in_features, out_features)
        self.bn = torch.nn.BatchNorm1d(in_features)

    def forward(self, input):
        ret = torch.matmul(input, self.w)
        ret = self.bn(ret)
        ret = torch.nn.functional.sigmoid(ret)
        plt.subplot(1, 10, self.n)
        plt.xlim(0,1)
        plt.hist(ret.detach().numpy().reshape(-1))
        return ret


def train():
    net = torch.nn.Sequential()

    for i in range(10):
        net.add_module("linear%d" % (i), Layer(N, N, i + 1))

    x = torch.rand(10, N)
    net(x)
    plt.show()


train()

batch_normal.png

1.2. DNN 的 batchnorm

DNN 的 batchnorm 是在 batch 的维度上进行平均, 即:

batch1: a b c d e
batch2: x y z m n

batchnorm 后为:
batch1: a-u1, b-u2, c-u3...
batch2: x-u1, b-u2, c-u3...

u1 是 mean(a,x), u2 是 mean(b-y)...

mean 的个数和 output 的 shape 相等

1.3. CNN 的 batchnorm

把一个 channel 看做整体, CNN 的 batchnorm 与 DNN 类似, 即:

batch1: channel1:a b c  channel2: a' b' c'
                 d e f            d' e' f'

batch2: channel1:x y z  channel2: x' y' z'
                 m n p            m' n' p'

batchnorm 后为:
batch1: channel1: a-u1 b-u1 c-u1  channel2: a'-u2 b'-u2...
                  ...                       ....
batch2: channel1: x-u1 y-u1 z-u1  channel2: x'-u2 y'-u2...                  
                  ...                       ...

u1 是  mean(a,b,c,d,e,f,x,y,z,m,n,p)
u2 是  mean(a',b',c',d',e',f',x',y',z',m',n',p')

考虑到同一个 channel 的数据都是用同一个 kernel 计算出来的, mean 的个数和 kernel 的个数 (或 output channel 个数) 相同, 即以 channel 为单位计算 mean

tensorflow 的 BatchNormalization 函数接受一个参数 axis, 默认为 -1, 对于 DNN, -1 即 batch 之外那个 axis, 对于 CNN, -1 为 channel.

BatchNormalization(axis=-1) 与 sum(axis=-1) 中 axis 的意义不同: sum 的 axis 指去掉这一维, BatchNormalization 的 axis 指保留这一维

1.4. batchnorm->relu vs. relu->batchnorm

1.4.1. batchnorm->relu 的优点

tflite 使用 batchnorm->relu 时, 可以把 batchnorm 和 relu 融合 (fuse) 到前面一层 cnn 或 dnn, 但使用 relu->batchnorm 时无法融合, 因为 batchorm 在 inference 时相当于对 input 的线性变换, 如果它的前一层也是一个线性变换, 而两者可以很简单的合并成一个线性变换, 但两者之前加上一个 relu 就没法操作了, 因为 relu 并不是线性的

另一方面, relu 放在 batchnorm 之前可能会出来 BatchNorm 导致很大的量化误差 的问题

1.4.2. relu->batchnorm 优点

从实践上看, relu->batchnorm 训练出来的模型的性能更好一点

1.5. Layer Normalization

  • batchnorm 是使用 axis(-1) (即 channel) 做 normalization
  • layernorm 使用 axis(0) (即 batch) 做 normalization

norm.png

使用 layernorm 的原因有两个:

  • batchnorm 需要较大的 batch size 才能工作
  • rnn 里同一个 batch 里样本的大小可能不同

Backlinks

Attention (Attention > Transformer > Add&Norm): - Add 是 ResNet 中的 skip connection - Norm 是 Layer Normalization

Backlinks

Weight Initialization (Weight Initialization > Batch Normalization): Batch Normalization

Author: [email protected]
Date: 2021-08-31 Tue 00:00
Last updated: 2022-02-07 Mon 20:26

知识共享许可协议