pytorch

code

def requant_shift(input, requant_scale, shift, bit=12, Sign=False):
    input = input.double() * requant_scale / 2**int(shift.data.cpu().numpy()[0])       
    input = torch.floor(input).float()

转成了numpy之后,在用torch.jit.trace跟踪模型时,该值就会变成一个常量prim::Constant,如果没有转,会通过prim::GetAttr来获取变量。

没有转numpy
【pytorch】—— Converting a tensor to a NumPy array might cause the trace to be incorrect. We can‘t rec_深度学习

转了numpy之后
会有这样的一句提示

 TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

【pytorch】—— Converting a tensor to a NumPy array might cause the trace to be incorrect. We can‘t rec_深度学习_02
可以看到shift_nums不见了,变成了一个值为32的常量。