Traceback (most recent call last):
File "/home/aistudio/bnn_pyro_fso_middle_2_16__32.py", line 426, in <module>
loss = svi.step(datax, datay)
File "/home/aistudio/external-libraries/pyro/infer/svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/home/aistudio/external-libraries/pyro/infer/elbo.py", line 237, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 57, in _get_trace
model_trace, guide_trace = get_importance_trace(
File "/home/aistudio/external-libraries/pyro/infer/enum.py", line 60, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 216, in get_trace
self(*args, **kwargs)
File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 198, in __call__
raise exc from e
File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 191, in __call__
ret = self.fn(*args, **kwargs)
File "/home/aistudio/external-libraries/pyro/nn/module.py", line 520, in __call__
result = super().__call__(*args, **kwargs)
File "/home/aistudio/external-libraries/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aistudio/external-libraries/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 759, in forward
self._setup_prototype(*args, **kwargs)
File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 940, in _setup_prototype
self.loc = nn.Parameter(self._init_loc())
File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 669, in _init_loc
latent = torch.cat(parts)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)
Trace Shapes:
Param Sites:
Sample Sites:
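调试方法 (Debugging approach)
修改 (Modification): in pyro/infer/autoguide/guides.py, instrument `_init_loc` to print the device of each stochastic site's constrained and unconstrained value, to find which site is on the CPU: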
def _init_loc(self):
    """
    Create an initial unconstrained latent vector using a per-site init function.

    Each stochastic site recorded in ``self.prototype_trace`` has its value
    detached and mapped to unconstrained space with
    ``biject_to(site["fn"].support).inv``. It is then flattened, and the parts
    are concatenated in iteration order.

    :returns: 1-D tensor whose size is ``(self.latent_dim,)``.
    :raises RuntimeError: if the site values are on more than one device.
        The message lists every site with its device, so that the site whose
        distribution was built on a different device can be located (for
        example, a prior constructed from CPU tensors while the data is on
        CUDA).
    """
    names = []
    parts = []
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        constrained_value = site["value"].detach()
        unconstrained_value = biject_to(site["fn"].support).inv(constrained_value)
        names.append(name)
        parts.append(unconstrained_value.reshape(-1))

    # torch.cat needs every part on a single device. Report which sites
    # disagree instead of letting it fail with an opaque device error.
    devices = {part.device for part in parts}
    if len(devices) > 1:
        listing = ", ".join(
            f"{name!r}: {part.device}" for name, part in zip(names, parts)
        )
        raise RuntimeError(
            "Expected all latent sites to be on the same device, but found "
            f"{sorted(str(d) for d in devices)}. Site devices: {listing}. "
            "Check that the distributions of these sites are constructed "
            "from tensors on a single device (the same one as the data)."
        )

    latent = torch.cat(parts)
    assert latent.size() == (self.latent_dim,)
    return latent