前言
前提
前几天在写一段代码的时候,遇到了nan
错误,但是一直么有找到问题所在。我的代码主要是将一个batch 中的entity的embedding送到自定义的2层BertModel中。
可能导致的原因
- 送入到BertModel中仅传入了 hidden_states,没有传入attention_mask。但是后来搞定这个bug之后还是有nan值的村子。
- 于是按行debug,发现问题在于送入到BertEncoder中的 hidden_states 中有nan的存在。
我们可以用下面这行代码检查tensor中是否有nan 值。
torch.any(torch.isnan(entity_bank))
这个 entity_bank
便是一个待检测的向量。于是对应解决这个nan值的生成即可。
但奇怪的是,使用下图中注释的一行就会有nan 的问题,但是用第二行代码则没有这个问题。