ResNet(Residual Network)是一种深度残差神经网络,由微软研究院提出。它的关键在于使用残差结构,解决在训练深度神经网络时出现的梯度消失问题。
ResNet的基本块是残差块,主要有两种类型:
- ResNet V1:
x = ConvLayer(x)
x = ConvLayer(x)
x = x + shortcut # shortcut直连x
x = Activation(x)
- ResNet V2:
x = ConvLayer(x)
x = ConvLayer(x)
shortcut = ConvLayer(x) # shortcut也做卷积变换
x = x + shortcut
x = Activation(x)
ResNet通过在主路径上加残差块,实现深层网络结构。主要有以下优点:
- 残差连接可以跨过多层网络直接传播信号,有效地解决梯度消失问题,利于网络训练。
- 残差块内包含若干个卷积层,但输出和输入尺度保持一致,便于残差相加。
- 残差网络极深,152层的ResNet在ImageNet上取得 state-of-the-art的结果。网络越深代表提取的特征越抽象和语义化。
- 残差连接引入的 shortcut路径使得网络中低层特征也可以直接到达最终输出,这种多阶特征结合的方式可以产生更强的表示能力。
- 残差块使得网络每一部分的功能非常清晰,有利于训练和理解。
ResNet代码示例:
python
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(out_channels, out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
所以,ResNet通过残差 Learning 的方式成功地训练了152层的超深度神经网络,大大推动了CNN在计算机视觉的进展。它已成为图像分类和目标检测等视觉任务的基础网络结构之一。