pytorch
code
def requant_shift(input, requant_scale, shift, bit=12, Sign=False):
input = input.double() * requant_scale / 2**int(shift.data.cpu().numpy()[0])
input = torch.floor(input).float()
转成了numpy之后,在用torch.jit.trace
跟踪模型时,该值就会变成一个常量prim::Constant
,如果没有转,会通过prim::GetAttr
来获取变量。
没有转numpy
转了numpy之后
会有这样的一句提示
TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
可以看到shift_nums
不见了,变成了一个值为32的常量。