pytorch CrossEntropyloss使用方法(包括多维度)

pytorch CrossEntropyloss使用方法(包括多维度)

官方文档给出的用法如下:

也就是说,在网络的output要把分类放在第二维,第二维后面的代表的是网络的维度,看起来非常简单,示例代码如下:

loss = nn.CrossEntropyLoss()
input = torch.randn(9, 5, 2)
target = torch.empty(9, 2, dtype = torch.long).random_(5)
output = loss(input, target)