PyTorch 激活函数的实现示例
更新时间:2025年12月24日 09:10:00 作者:byxdaz
激活函数是神经网络中至关重要的组成部分,本文就来详细的介绍一下PyTorch常用的激活函数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
激活函数是神经网络中至关重要的组成部分,它们为网络引入了非线性特性,使得神经网络能够学习复杂模式。PyTorch 提供了多种常用的激活函数实现。
常用激活函数
1. ReLU (Rectified Linear Unit)
数学表达式:

PyTorch实现:
torch.nn.ReLU(inplace=False)
特点:
- 计算简单高效
- 解决梯度消失问题(正区间)
- 可能导致"神经元死亡"(负区间梯度为0),ReLU 在输入为负时输出恒为 0,导致反向传播中梯度消失,相关权重无法更新14。若神经元长期处于负输入状态,则会永久“死亡”,失去学习能力。
示例:
relu = nn.ReLU() input = torch.tensor([-1.0, 0.0, 1.0, 2.0]) output = relu(input) # tensor([0., 0., 1., 2.])
2. LeakyReLU
数学表达式:

PyTorch实现:
torch.nn.LeakyReLU(negative_slope=0.01, inplace=False)
特点:
- 解决了ReLU的"神经元死亡"问题,通过引入负区间的微小斜率(如 torch.nn.LeakyReLU(negative_slope=0.01)),保留负输入的梯度传播,避免神经元死亡。
- negative_slope通常设为0.01
示例
leaky_relu = nn.LeakyReLU(negative_slope=0.1) input = torch.tensor([-1.0, 0.0, 1.0, 2.0]) output = leaky_relu(input) # tensor([-0.1000, 0.0000, 1.0000, 2.0000])
3. Sigmoid
数学表达式:

PyTorch实现:
torch.nn.Sigmoid()
特点:
- 输出范围(0,1),适合二分类问题
- 容易出现梯度消失问题
- 输出不以0为中心
示例:
sigmoid = nn.Sigmoid() input = torch.tensor([-1.0, 0.0, 1.0, 2.0]) output = sigmoid(input) # tensor([0.2689, 0.5000, 0.7311, 0.8808])
4. Tanh (Hyperbolic Tangent)
数学表达式:

PyTorch实现:
torch.nn.Tanh()
特点:
- 输出范围(-1,1),以0为中心
- 比sigmoid梯度更强
- 仍存在梯度消失问题
示例:
tanh = nn.Tanh() input = torch.tensor([-1.0, 0.0, 1.0, 2.0]) output = tanh(input) # tensor([-0.7616, 0.0000, 0.7616, 0.9640])
5. Softmax
数学表达式:

PyTorch实现:
torch.nn.Softmax(dim=None)
特点:
- 输出为概率分布(和为1)
- 常用于多分类问题的输出层
- dim参数指定计算维度
示例:
softmax = nn.Softmax(dim=1) input = torch.tensor([[1.0, 2.0, 3.0]]) output = softmax(input) # tensor([[0.0900, 0.2447, 0.6652]])
其他激活函数
6. ELU (Exponential Linear Unit)
torch.nn.ELU(alpha=1.0, inplace=False)
7. GELU (Gaussian Error Linear Unit)
torch.nn.GELU()
8. Swish
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)选择指南
- 隐藏层:通常首选ReLU及其变体(LeakyReLU、ELU等)
- 二分类输出层:Sigmoid
- 多分类输出层:Softmax
- 需要负输出的情况:Tanh或LeakyReLU
- Transformer模型:常用GELU
自定义激活函数
PyTorch可以轻松实现自定义激活函数:
class CustomActivation(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.where(x > 0, x, torch.exp(x) - 1)注意事项
- 梯度消失/爆炸问题
- 死亡神经元问题(特别是ReLU)
- 计算效率考虑
- 初始化方法应与激活函数匹配
相关文章
使用Django Form解决表单数据无法动态刷新的两种方法
这篇文章主要介绍了使用Django Form解决表单数据无法动态刷新的两种方法,需要的朋友可以参考下2017-07-07
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
今天小编就为大家分享一篇计算pytorch标准化(Normalize)所需要数据集的均值和方差实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-01-01
浅谈opencv自动光学检测、目标分割和检测(连通区域和findContours)
这篇文章主要介绍了浅谈opencv自动光学检测、目标分割和检测(连通区域和findContours),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2020-06-06


最新评论