残差网络(Residual Network)是一种广泛使用的深度神经网络结构。
它的基本思想是:
通过在相邻层之间添加“残差连接”,实现信息的直接传递,避免信息在经过多层处理后消失或爆炸。
一个简单的残差块,当输入x经过两层网络计算得到F(x)时,如果直接将x添加到F(x)中,那么最终的输出就是x + F(x),也就是输入x和理论上F(x)应该逼近的目标值之和。
这实际上使得网络只需要学习输入x与输出y之间的差值(残差)F(x),而不是直接学习 x 到 y 的映射关系。这减轻了深层网络的参数优化难度,有助于解决梯度消失和爆炸问题。
代码示例:
python
import torch
import torch.nn as nn
# 残差块
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=False):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
if downsample:
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
else:
self.downsample = None
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
可以看到,残差块通过residual连接实现了输入的直接传递,输出是输入与block内两层网络输出的和。这使得网络只需要学习输入和输出的差值,简化了优化难度。
残差网络通过堆叠多个残差块,实现了信息的直接传播,解决了深层网络训练中的梯度消失问题,使得网络可以继续加深,达到上百层的深度。这是残差网络可以训练超深度模型的关键。
所以,残差网络通过残差学习和信息直接传播的思想,成功地训练了深度达上百层的模型,解决了深层网络训练的瓶颈,大大提高了模型的表征能力。理解残差网络的原理与结构,可以帮助我们构建更深更强的神经网络模型。