Pytorch RuntimeError 解决办法
问题描述
在Pytorch训练自定义数据集中发生如下错误:
RuntimeError: result type Float can't be cast to the desired output type Long
RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long
1 | loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights])) |
问题解决
BCEWithLogitsLoss
要求它的目标是一个float
张量,而不是long
。所以应该通过dtype=torch.float32
指定张量的类型。
将上述代码修改如下:
1 | loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32)) |
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 花猪のBlog!
评论
TwikooWaline