torch.nn.CrossEntropyLoss用法

文章目录[隐藏]

前言

早上想花一个小时参照网上其他教程,修改模型结构,写一个手写识别数字的出来,结果卡在了这个上面,loss一直降不下来,然后我就去查看了一下CrossEntropyLoss的用法,毕竟分类问题一般都用这个。

代码

引入一个库:

import torch

假如是一个四分类任务,batch为2(只是为了显示简单,举个例子罢了)

logists = torch.randn(2, 4, requires_grad=True)
print(logists)

其实根据这个模型预测出来就是, 第一个样本预测的类别是1, 第二个样本预测的类别是2。
这里我们假设模型足够好,都预测对了,那么其实target就是ground_truth。

target = logists.argmax(dim=-1)

通过查看官方文档 CrossEntropyLoss–PyTorch 1.10.1 document, 可以知道loss有两种算法。
target可one-hot也可以不one-hot。

定义损失函数:

crition = torch.nn.CrossEntropyLoss()

先来看个target_1d版的loss:

crition(logists, target)

再来看个target one-hot版的:
先把target转为one

t_onehot = torch.nn.functional.one_hot(target)

如何是one_hot, 要求target也是浮点类型的,所以t_onehot再调用float()转为浮点类型。

crition(logists, t_onehot.float())

最后发现两种方法其实算出来的loss都是0.5601


另外插一嘴,crossEntropyLoss也可以通过nll_loss实现(如果你去看torch.nn.crossEntropyLoss的源码就会发现官方就是使用torch.nn.functional.nll_loss实现的,只不过模型输出的logists值要先经过log_softmax

版权声明:本文为CSDN博主「Andy Dennis」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_43850253/article/details/122510794

Andy Dennis

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

YOLOv3 YOLOv4 YOLOv5老鼠识别检测告警

前言 在食品安全众多环节中,后厨安全无疑是重中之重。俗话说“民以食为天,食以安为先”,食材新鲜程度如何、加工过程规不规范、厨具是否经过清洁消毒等问题,备受大家关注。 一、为什么需要AI检