Traceback (most recent call last):
  File "/home/aistudio/bnn_pyro_fso_middle_2_16__32.py", line 426, in <module>
    loss = svi.step(datax, datay)
  File "/home/aistudio/external-libraries/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/home/aistudio/external-libraries/pyro/infer/elbo.py", line 237, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 57, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "/home/aistudio/external-libraries/pyro/infer/enum.py", line 60, in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
  File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 216, in get_trace
    self(*args, **kwargs)
  File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 198, in __call__
    raise exc from e
  File "/home/aistudio/external-libraries/pyro/poutine/trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/aistudio/external-libraries/pyro/nn/module.py", line 520, in __call__
    result = super().__call__(*args, **kwargs)
  File "/home/aistudio/external-libraries/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aistudio/external-libraries/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 940, in _setup_prototype
    self.loc = nn.Parameter(self._init_loc())
  File "/home/aistudio/external-libraries/pyro/infer/autoguide/guides.py", line 669, in _init_loc
    latent = torch.cat(parts)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper_CUDA_cat)
Trace Shapes:
 Param Sites:
Sample Sites:

调试方法：在 `_init_loc` 中打印每个采样点（sample site）的 constrained / unconstrained 值所在的设备，定位哪个采样点与其他采样点不在同一设备上。

修改的文件：

pyro/infer/autoguide/guides.py

def _init_loc(self):
        """
        Creates an initial latent vector using a per-site init function.
        """
        parts = []
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            constrained_value = site["value"].detach()
            unconstrained_value = biject_to(site["fn"].support).inv(constrained_value)
            parts.append(unconstrained_value.reshape(-1))
            print(name, site)
            if constrained_value.device.type == 'cuda':
                print("constrained_value is on CUDA")
            else:
                print("constrained_value is on CPU")
            if unconstrained_value.device.type == 'cuda':
                print("unconstrained_value is on CUDA")
            else:
                print("unconstrained_value is on CPU")      
        latent = torch.cat(parts)
        assert latent.size() == (self.latent_dim,)
        return latent