【Pytorch基础】torch.nn.CrossEntropyLoss损失函数介绍
1 交叉熵的定义
交叉熵主要是用来判定实际的输出与期望的输出的接近程度,为什么这么说呢,举个例子:在做分类的训练的时候,如果一个样本属于第K类,那么这个类别所对应的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果。也就是说用它来衡量网络的输出与标签的差异,利用这种差异经过反向传播去更新网络参数。参考文献【1】
2 交叉熵的数学原理
Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解
3 Pytorch交叉熵实现
3.1 举个栗子
交叉熵损失,是分类任务中最常用的一个损失函数。在Pytorch中是基于下面的公式实现的。
Loss
(
x
^
,
x
)
=
−
∑
i
=
1
n
x
log
(
x
^
)
\operatorname{Loss}(\hat{x}, x)=-\sum_{i=1}^{n} x \log (\hat{x})
Loss(x^,x)=−i=1∑nxlog(x^)
其中
x
x
x是真实标签,
x
^
\hat{x}
x^ 是预测的类分布(通常是使用softmax将模 型输出转换为概率分布)。
取单个样本举例, 假设
x
1
=
[
0
,
1
,
0
]
x_1=[0, 1, 0]
x1=[0,1,0], 模型预测样本
x
1
x_1
x1的概率为
x
1
^
=
[
0.1
,
0.5
,
0.4
]
\hat{x_1}=[0.1, 0.5, 0.4]
x1^=[0.1,0.5,0.4](因为是分布, 所以属于各个类的和为1)。则样本的损失计算如下所示:
Loss ( x 1 ^ , x 1 ) = − 0 × log ( 0.1 ) − 1 × log ( 0.5 ) − 0 × log ( 0.4 ) = log ( 0.5 ) \operatorname{Loss}(\hat{x_1}, x_1)=-0 \times \log (0.1)-1 \times \log (0.5)-0 \times \log (0.4)=\log (0.5) Loss(x1^,x1)=−0×log(0.1)−1×log(0.5)−0×log(0.4)=log(0.5)
更详细的多分类交叉熵损失函数的例子可以参考文献【4】
3.2 Pytorch实现
实际使用中需要注意几点:
- torch.nn.CrossEntropyLoss(input, target)中的标签target使用的不是one-hot形式,而是类别的序号。形如 target = [1, 3, 2] 表示3个样本分别属于第1类、第3类、第2类。(单标签多分类问题)
- torch.nn.CrossEntropyLoss(input, target)的input是没有归一化的每个类的得分,而不是softmax之后的分布。
输入的形式大概如下所示:
import torch
target = [1, 3, 2]
input_ = [[0.13, -0.18, 0.87],
[0.25, -0.04, 0.32],
[0.24, -0.54, 0.53]]
# 然后就将他们扔到CrossEntropyLoss函数中,就可以得到损失。
loss_item = torch.nn.CrossEntropyLoss()
loss = loss_item(input, target)
CrossEntropyLoss函数里面的实现,如下所示:
def forward(self, input, target):
return F.cross_entropy(input, target, weight=self.weight,
ignore_index=self.ignore_index, reduction=self.reduction)
是调用的torch.nn.functional(俗称F)中的cross_entropy()函数。
此处需要区分一下:torch.nn.Module 和 torch.nn.functional(俗称F)中损失函数的区别。Module的损失函数例如CrossEntropyLoss、NLLLoss等是封装之后的损失函数类,是一个类,因此其中的变量可以自动维护。经常是对F中的函数的封装。而F中的损失函数只是单纯的函数。
下面看一下F.cross_entropy函数
3.3 F.cross_entropy
- input:预测值,(batch,dim),这里dim就是要分类的总类别数
- target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。
- weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。
- ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。
- reduction:在[none, mean, sum]中选,string型。none表示不降维,返回和target相同形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
其中参数weight、ignore_index、reduction要在实例化CrossEntropyLoss对象时指定,例如:
loss = torch.nn.CrossEntropyLoss(reduction='none')
F中的cross_entropy的实现
return nll_loss(log_softmax(input, dim=1), target, weight, None, ignore_index, None, reduction)
可以看到就是先调用log_softmax,再调用nll_loss。log_softmax就是先softmax再取log。
4 参考文献
[1]Pytorch常用损失函数拆解
[2]Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解
[3]负对数似然(negative log-likelihood)
[4]损失函数|交叉熵损失函数