Pyro bug:
dist.Gamma 分布似乎不支持 CUDA,在 CUDA 上使用会报错,可能是数值太小的问题
文档
glaringlee.github.io/_modules/torch/distributions/gamma.html
官方文档
pytorch.org/docs/stable/genindex.html
pytorch.org/docs/stable/distributions.html#torch.distributions.gamma.Gamma.arg_constraints
Probability distributions - torch.distributions — PyTorch 2.4 documentation
让 Gamma 分布支持 CUDA 的解决办法:参数 .5 和 1 要写成 [0.5] 和 [1.0],即需要带小数点,并使用 list 格式
pyro 的写法
sigma = pyro.sample("sigma", dist.Gamma(torch.tensor([0.5], device=self.device), torch.tensor([1.0], device=self.device)))
pytorch的写法
Creates an inverse gamma distribution parameterized by concentration and rate where:
    X ~ Gamma(concentration, rate)
    Y = 1 / X ~ InverseGamma(concentration, rate)
Example:
    >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0]))
    >>> m.sample()
    tensor([ 1.2953])
www.allaboutlean.com
图片来自 AllAboutLean.com – Organize your Industry 管理者个人博客
# Apply the last layer to x and squeeze the result.
mu = self.layers[-1](x).squeeze() # hidden --> output
# Sample the "sigma" site from a Gamma whose parameters are 0-dim tensors
# created on self.device. The rate is built from the integer literal 1.
# NOTE(review): presumably that makes the rate an integer-dtype tensor, unlike
# the [0.5] / [1.0] form shown earlier in these notes -- confirm whether the
# dtype is what breaks on CUDA.
sigma = pyro.sample("sigma", dist.Gamma(torch.tensor(.5, device=self.device), torch.tensor(1, device=self.device))) # infer the response noise
# Replace mu with a cloned, detached copy moved to self.device.
# NOTE(review): detach() presumably stops gradients from flowing through mu
# and sigma_squared -- confirm this is intended.
mu = mu.clone().detach().to(self.device)
# sigma squared, likewise cloned, detached and moved to self.device.
sigma_squared = (sigma * sigma).clone().detach().to(self.device)
Pyro 的 dist.Gamma 其实就是 torch.distributions.Gamma 的扩展
K:\ProgramData\Anaconda3\envs\py37\Lib\site-packages\pyro\distributions\torch.py
class Gamma(torch.distributions.Gamma, TorchDistributionMixin):
    # torch.distributions.Gamma combined with TorchDistributionMixin.

    def conjugate_update(self, other):
        """
        EXPERIMENTAL.

        Combine this distribution with another ``Gamma``.

        Returns a pair ``(updated, log_normalizer)``:

        * ``updated`` is a ``Gamma`` with concentration
          ``self.concentration + other.concentration - 1`` and rate
          ``self.rate + other.rate``;
        * ``log_normalizer`` is the log-normalizer of ``self`` plus that of
          ``other`` minus that of ``updated``, where each log-normalizer is
          ``rate.log() * concentration - concentration.lgamma()``.
        """
        assert isinstance(other, Gamma)

        def _log_normalizer(d):
            # rate.log() * concentration - lgamma(concentration)
            c = d.concentration
            return d.rate.log() * c - c.lgamma()

        updated = Gamma(
            self.concentration + other.concentration - 1,
            self.rate + other.rate,
        )
        # Same operand order as the original: self + other - updated.
        combined = _log_normalizer(self) + _log_normalizer(other)
        return updated, combined - _log_normalizer(updated)
============= 下面几种写法都不支持 cuda ========
# Variant with bare Python numbers: the Gamma parameters are given as .5 and 1,
# with no device specified anywhere in the call.
# Apply the last layer to x and squeeze the result.
mu = self.layers[-1](x).squeeze() # hidden --> output
# NOTE(review): presumably the distribution (and therefore the sampled sigma)
# is created on the default device rather than self.device -- confirm.
sigma = pyro.sample("sigma", dist.Gamma(.5, 1)) # infer the response noise
# Replace mu with a cloned, detached copy moved to self.device.
mu = mu.clone().detach().to(self.device)
# sigma squared, likewise cloned, detached and moved to self.device.
sigma_squared = (sigma * sigma).clone().detach().to(self.device)
# Variant with 0-dim tensors: both Gamma parameters are created on self.device,
# but the rate is built from the integer literal 1.
# Apply the last layer to x and squeeze the result.
mu = self.layers[-1](x).squeeze() # hidden --> output
# NOTE(review): presumably torch.tensor(1, ...) yields an integer-dtype tensor,
# unlike the [0.5] / [1.0] form shown earlier in these notes -- confirm whether
# the dtype is the cause of the failure.
sigma = pyro.sample("sigma", dist.Gamma(torch.tensor(.5, device=self.device), torch.tensor(1, device=self.device))) # infer the response noise
# Replace mu with a cloned, detached copy moved to self.device.
mu = mu.clone().detach().to(self.device)
# sigma squared, likewise cloned, detached and moved to self.device.
sigma_squared = (sigma * sigma).clone().detach().to(self.device)
调试过程,修改程序包中的代码
pyro/infer/autoguide/guides.py
def _init_loc(self):
    """
    Creates an initial latent vector using a per-site init function.

    For every stochastic node of ``self.prototype_trace`` the detached site
    value is mapped through the inverse of ``biject_to(site["fn"].support)``,
    flattened, and collected; the pieces are concatenated into one vector
    whose size must equal ``(self.latent_dim,)``.

    Debugging aid: for each site, prints the site itself and whether the
    constrained / unconstrained values live on CUDA or on the CPU.
    """
    # NOTE(review): the pasted source lost its indentation; the debug prints
    # are assumed to belong inside the per-site loop -- confirm.
    flattened = []
    for site_name, node in self.prototype_trace.iter_stochastic_nodes():
        constrained = node["value"].detach()
        unconstrained = biject_to(node["fn"].support).inv(constrained)
        flattened.append(unconstrained.reshape(-1))

        print(site_name, node)
        print(
            "constrained_value is on CUDA"
            if constrained.device.type == 'cuda'
            else "constrained_value is on CPU"
        )
        print(
            "unconstrained_value is on CUDA"
            if unconstrained.device.type == 'cuda'
            else "unconstrained_value is on CPU"
        )

    latent = torch.cat(flattened)
    assert latent.size() == (self.latent_dim,)
    return latent
UserWarning: Encountered +inf: log_prob_sum at site 'sigma'
warn_if_inf( | 0/15000 [00:20<?, ?it/s, loss=-inf]
Traceback (most recent call last):
File "/home/aistudio/bnn_pyro_fso_middle_2_16__32.py", line 426, in <module>
loss = svi.step(datax, datay)
File "/home/aistudio/external-libraries/pyro/infer/svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/home/aistudio/external-libraries/pyro/infer/elbo.py", line 237, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/home/aistudio/external-libraries/pyro/infer/trace_elbo.py", line 57, in _get_trace
model_trace, guide_trace = get_importance_trace(
File "/home/aistudio/external-libraries/pyro/infer/enum.py", line 75, in get_importance_trace
model_trace.compute_log_prob()
File "/home/aistudio/external-libraries/pyro/poutine/trace_struct.py", line 264, in compute_log_prob
log_p = site["fn"].log_prob(
File "/home/aistudio/external-libraries/torch/distributions/independent.py", line 108, in log_prob
log_prob = self.base_dist.log_prob(value)
File "/home/aistudio/external-libraries/torch/distributions/normal.py", line 89, in log_prob
- math.log(math.sqrt(2 * math.pi))
KeyboardInterrupt