问题描述

在Pytorch训练自定义数据集中发生如下错误:

RuntimeError: result type Float can't be cast to the desired output type Long

RuntimeError:结果类型 Float 无法转换为所需的输出类型 Long

1
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights]))

问题解决

BCEWithLogitsLoss 要求它的目标是一个float 张量,而不是long。所以应该通过dtype=torch.float32指定张量的类型。

将上述代码修改如下:

1
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_weights], dtype=torch.float32))

参考文章:Pytorch 抛出错误 RuntimeError: result type Float can’t be cast to the desired output type Long答案 - 爱码网 (likecs.com)