前言
早上想花一个小时参照网上其他教程,修改模型结构,写一个手写识别数字的出来,结果卡在了这个上面,loss一直降不下来,然后我就去查看了一下CrossEntropyLoss的用法,毕竟分类问题一般都用这个。
代码
引入一个库:
import torch
假如是一个四分类任务,batch为2(只是为了显示简单,举个例子罢了)
logists = torch.randn(2, 4, requires_grad=True)
print(logists)
其实根据这个模型预测出来就是, 第一个样本预测的类别是1, 第二个样本预测的类别是2。
这里我们假设模型足够好,都预测对了,那么其实target就是ground_truth。
target = logists.argmax(dim=-1)
通过查看官方文档 CrossEntropyLoss–PyTorch 1.10.1 document, 可以知道loss有两种算法。
target可one-hot也可以不one-hot。
定义损失函数:
crition = torch.nn.CrossEntropyLoss()
先来看个target_1d版的loss:
crition(logists, target)
再来看个target one-hot版的:
先把target转为one
t_onehot = torch.nn.functional.one_hot(target)
如何是one_hot, 要求target也是浮点类型的,所以t_onehot再调用float()转为浮点类型。
crition(logists, t_onehot.float())
最后发现两种方法其实算出来的loss都是0.5601
另外插一嘴,crossEntropyLoss也可以通过nll_loss实现(如果你去看torch.nn.crossEntropyLoss的源码就会发现官方就是使用torch.nn.functional.nll_loss实现的,只不过模型输出的logists值要先经过log_softmax
版权声明:本文为CSDN博主「Andy Dennis」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_43850253/article/details/122510794
暂无评论