• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python pyro.sample函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中pyro.sample函数的典型用法代码示例。如果您正苦于以下问题:Python sample函数的具体用法?Python sample怎么用?Python sample使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了sample函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: model

 def model():
     """Sample two one-hot categorical sites and assert their shapes.

     ``iarange_shape`` is not defined in this block; it must be supplied by
     an enclosing or module-level scope.
     """
     # Build the uniform probability vectors directly. Wrapping an existing
     # tensor in ``torch.tensor(...)`` only makes a redundant copy and makes
     # PyTorch emit its copy-construct UserWarning.
     p2 = torch.ones(2) / 2
     p3 = torch.ones(3) / 3
     x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
     x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
     assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
     assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_valid_models.py


示例2: bernoulli_normal_model

def bernoulli_normal_model():
    """Sample a Bernoulli site, then observe a Normal site centred on it.

    The Normal location is +1 when the Bernoulli draw is truthy and -1
    otherwise; the observed value is a tensor of ones.

    :return: ``[bernoulli_sample, observed_value]``
    """
    coin = pyro.sample('bern_0', dist.Bernoulli(torch.zeros(1) * 1e-2))
    if coin.item():
        center = torch.ones(1)
    else:
        center = -torch.ones(1)
    observed = torch.ones(1)
    likelihood = dist.Normal(center, torch.ones(1) * 1e-2)
    pyro.sample('normal_0', likelihood, obs=observed)
    return [coin, observed]
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_properties.py


示例3: guide

 def guide():
     """Register ``pt_guide`` with Pyro and sample the ``mu_latent`` site.

     ``pt_guide`` and ``reparameterized`` come from an outer scope that is
     not part of this block.
     """
     pyro.module("mymodule", pt_guide)
     loc_q = torch.exp(pt_guide.mu_q_log)
     tau = torch.exp(pt_guide.tau_q_log)
     # The Normal scale is tau ** -0.5.
     scale_q = torch.pow(tau, -0.5)
     latent_dist = dist.Normal(loc_q, scale_q, reparameterized=reparameterized)
     pyro.sample("mu_latent", latent_dist,
                 baseline={"use_decaying_avg_baseline": True})
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_tracegraph_elbo.py


示例4: _register_param

    def _register_param(self, param, mode="model"):
        """
        Make ``param`` available through Pyro and cache the resulting value
        in ``self._registered_params``.

        The value comes from ``self._fixed_params`` when the parameter is
        fixed, from :func:`pyro.param` when no prior is set, and from
        :func:`pyro.sample` otherwise.

        :param str param: Name of the parameter.
        :param str mode: Either "model" or "guide".
        """
        # Fixed parameters are stored as-is and never touch Pyro.
        if param in self._fixed_params:
            self._registered_params[param] = self._fixed_params[param]
            return

        prior = self._priors.get(param)
        full_name = (param if self.name is None
                     else param_with_module_name(self.name, param))

        if prior is not None:
            if mode == "model":
                value = pyro.sample(full_name, prior)
            else:
                # Any non-"model" mode is treated as the guide: sample from a
                # Delta located at a learnable "<name>_MAP" parameter.
                # TODO: consider to init parameter from a prior call instead of mean
                map_value = pyro.param(full_name + "_MAP", prior.mean.detach())
                value = pyro.sample(full_name, dist.Delta(map_value))
        else:
            # No prior: a plain Pyro parameter, constrained only if a
            # constraint was recorded for it.
            constraint = self._constraints.get(param)
            initial = getattr(self, param)
            extra = {} if constraint is None else {"constraint": constraint}
            value = pyro.param(full_name, initial, **extra)

        self._registered_params[param] = value
开发者ID:lewisKit,项目名称:pyro,代码行数:33,代码来源:util.py


示例5: iarange_model

def iarange_model(subsample_size):
    """Sample site ``"x"`` inside a subsampled ``pyro.iarange`` of size 20.

    :param subsample_size: subsample size forwarded to ``pyro.iarange``.
    :return: the indices yielded by the ``iarange`` context, as a list.
    """
    size = 20
    loc, scale = torch.zeros(size), torch.ones(size)
    with pyro.iarange('iarange', size, subsample_size) as indices:
        batch_dist = dist.Normal(loc[indices], scale[indices])
        pyro.sample("x", batch_dist)
        result = list(indices.data)
    return result
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_mapdata.py


示例6: guide

 def guide(self, x):
     """Sample the ``"latent"`` site from the encoder's output for ``x``.

     :param x: input passed to ``self.encoder.forward``.
     """
     # Make the encoder module known to Pyro.
     pyro.module("encoder", self.encoder)
     # The encoder returns the two arguments handed to ``dist.normal``.
     encoded = self.encoder.forward(x)
     loc, scale = encoded
     # Draw the latent code.
     pyro.sample("latent", dist.normal, loc, scale)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:vae.py


示例7: guide

    def guide(self, xs, ys=None):
        """
        Variational guide with two factors:

        * ``q(y|x) = categorical(alpha(x))``, with ``alpha`` produced by
          ``encoder_y``; only sampled when ``ys`` is not supplied.
        * ``q(z|x,y) = normal(mu(x,y), sigma(x,y))``, with ``mu`` and
          ``sigma`` produced by ``encoder_z``.

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels, i.e. the digit
                   corresponding to the image(s)
        :return: None
        """
        # Declare the elements of the batch conditionally independent.
        with pyro.iarange("independent"):
            # Unsupervised case: sample (and score) the label from q(y|x).
            if ys is None:
                ys = pyro.sample("y", dist.categorical,
                                 self.encoder_y.forward(xs))

            # Sample (and score) the latent z from q(z|x,y); the sampled
            # value itself is not needed here.
            loc, scale = self.encoder_z.forward([xs, ys])
            pyro.sample("z", dist.normal, loc, scale)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:26,代码来源:ss_vae_M2.py


示例8: guide

 def guide(subsample_size):
     """Sample site ``"z"`` for a subsample of ``data``.

     ``data`` and ``reparameterized`` come from an outer scope that is not
     part of this block.

     :param subsample_size: subsample size forwarded to ``pyro.iarange``.
     """
     def _init_mu():
         return Variable(torch.zeros(len(data)), requires_grad=True)

     def _init_sigma():
         return Variable(torch.ones(1), requires_grad=True)

     mu = pyro.param("mu", _init_mu)
     sigma = pyro.param("sigma", _init_sigma)
     with pyro.iarange("data", len(data), subsample_size) as ind:
         # Select the subsampled locations and stretch the single scale to
         # the subsample size.
         batch_mu = mu[ind]
         batch_sigma = sigma.expand(subsample_size)
         z_dist = dist.Normal(batch_mu, batch_sigma,
                              reparameterized=reparameterized)
         pyro.sample("z", z_dist)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_gradient.py


示例9: sample_zs

 def sample_zs(name, width):
     """Sample the Gamma site ``"z_<name>"``.

     ``self``, ``x_size``, ``rand_tensor`` and ``Gamma`` come from an outer
     scope that is not part of this block.

     :param name: suffix used in the parameter and sample site names.
     :param width: second dimension of the parameter shape.
     """
     def _init_alpha():
         return rand_tensor((x_size, width), self.alpha_init, self.sigma_init)

     def _init_mean():
         return rand_tensor((x_size, width), self.mean_init, self.sigma_init)

     raw_alpha = pyro.param("log_alpha_z_q_%s" % name, _init_alpha)
     raw_mean = pyro.param("log_mean_z_q_%s" % name, _init_mean)
     # Pass both raw parameters through softplus before use.
     alpha = self.softplus(raw_alpha)
     mean = self.softplus(raw_mean)
     z_dist = Gamma(alpha, alpha / mean).independent(1)
     pyro.sample("z_%s" % name, z_dist)
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:sparse_gamma_def.py


示例10: model

 def model(num_particles):
     """Sample ``"z"`` (Normal) and ``"y"`` (Bernoulli) inside an iarange.

     ``pi1``, ``pi2`` and ``pi3`` come from an outer scope that is not part
     of this block.

     :param num_particles: size of the ``"particles"`` iarange and of the
         expanded Normal batch.
     """
     with pyro.iarange("particles", num_particles):
         q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
         q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
         z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
         # sigmoid(z) == exp(z) / (1 + exp(z)); the built-in avoids the
         # inf / inf = nan that the hand-written form gives once exp(z)
         # overflows.
         zz = torch.sigmoid(z)
         pyro.sample("y", dist.Bernoulli(q4 * zz))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_enum.py


示例11: guide

 def guide(num_particles):
     """Sample ``"z"`` (Normal) and ``"y"`` (Bernoulli) inside an iarange.

     ``pi1`` and ``pi2`` come from an outer scope that is not part of this
     block.

     :param num_particles: size of the ``"particles"`` iarange and of the
         expanded Normal batch.
     """
     q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
     q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
     with pyro.iarange("particles", num_particles):
         z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
         # sigmoid(z) == exp(z) / (1 + exp(z)); the built-in avoids the
         # inf / inf = nan that the hand-written form gives once exp(z)
         # overflows.
         zz = torch.sigmoid(z)
         pyro.sample("y", dist.Bernoulli(q1 * zz))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_enum.py


示例12: guide

 def guide(subsample):
     """Sample site ``"z"`` under nested ``particles`` / ``data`` iaranges.

     ``data``, ``num_particles``, ``subsample_size`` and ``Normal`` come
     from an outer scope that is not part of this block.

     :param subsample: forwarded as the fourth argument of the ``"data"``
         iarange.
     """
     def _init_loc():
         return torch.zeros(len(data), requires_grad=True)

     def _init_scale():
         return torch.tensor([1.0], requires_grad=True)

     loc = pyro.param("loc", _init_loc)
     scale = pyro.param("scale", _init_scale)
     with pyro.iarange("particles", num_particles), \
             pyro.iarange("data", len(data), subsample_size, subsample) as ind:
         # Add a trailing dimension to the selected locations and expand it
         # to ``num_particles``.
         selected = loc[ind].unsqueeze(-1)
         loc_ind = selected.expand(-1, num_particles)
         pyro.sample("z", Normal(loc_ind, scale))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_gradient.py


示例13: guide_step

    def guide_step(self, t, n, prev, inputs):
        """
        Run one guide step and sample the ``z_pres``, ``z_where`` and
        ``z_what`` sites for step ``t``.

        :param t: step index, used to build the sample site names.
        :param n: not used in this method body.
        :param prev: previous state; its ``z_where``, ``z_what``, ``z_pres``,
            ``h`` and ``c`` attributes are read here.
        :param inputs: mapping whose ``'embed'`` and ``'raw'`` entries are
            read here.
        :return: a ``GuideState`` holding the new recurrent state, baseline
            state and sampled values.
        """
        rnn_pieces = (inputs['embed'], prev.z_where, prev.z_what, prev.z_pres)
        h, c = self.rnn(torch.cat(rnn_pieces, 1), (prev.h, prev.c))
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # Compute baseline estimates for discrete choice z_pres.
        bl_value, bl_h, bl_c = self.baseline_step(prev, inputs)

        # Sample presence.
        pres_dist = dist.Bernoulli(z_pres_p * prev.z_pres).independent(1)
        pres_infer = dict(baseline=dict(baseline_value=bl_value.squeeze(-1)))
        z_pres = pyro.sample('z_pres_{}'.format(t), pres_dist, infer=pres_infer)

        # When masking is on, the later sites are masked by the sampled
        # presence; otherwise a constant 1.0 mask is used.
        if self.use_masking:
            sample_mask = z_pres
        else:
            sample_mask = torch.tensor(1.0)

        where_dist = dist.Normal(z_where_loc + self.z_where_loc_prior,
                                 z_where_scale * self.z_where_scale_prior)
        where_dist = where_dist.mask(sample_mask).independent(1)
        z_where = pyro.sample('z_where_{}'.format(t), where_dist)

        # Figure 2 of [1] shows x_att depending on z_where and h,
        # rather than z_where and x as here, but I think this is
        # correct.
        x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw'])

        # Encode attention windows.
        z_what_loc, z_what_scale = self.encode(x_att)

        what_dist = dist.Normal(z_what_loc, z_what_scale)
        what_dist = what_dist.mask(sample_mask).independent(1)
        z_what = pyro.sample('z_what_{}'.format(t), what_dist)

        return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c,
                          z_pres=z_pres, z_where=z_where, z_what=z_what)
开发者ID:lewisKit,项目名称:pyro,代码行数:35,代码来源:air.py


示例14: model

 def model(self, data):
     """Sample a chain of Normal sites and observe ``data`` at its end.

     Each ``loc_<i>`` site (``i`` from 1 to ``self.chain_len``) is centred
     on the previous one, starting from ``self.loc_0``; every site uses
     ``self.lambda_prec`` as its scale.

     :param data: observed value for the ``'obs'`` site.
     """
     current = self.loc_0
     scale = self.lambda_prec
     for step in range(1, self.chain_len + 1):
         link = dist.Normal(loc=current, scale=scale)
         current = pyro.sample('loc_{}'.format(step), link)
     pyro.sample('obs', dist.Normal(current, scale), obs=data)
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_hmc.py


示例15: guide

 def guide():
     """Sample Bernoulli sites ``x_<i>_<j>`` over two nested iranges.

     The iranges are created outer-first but iterated inner-first.
     ``subsample_size`` comes from an outer scope that is not part of this
     block.
     """
     p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
     rows = pyro.irange("irange_0", 3, subsample_size)
     cols = pyro.irange("irange_1", 3, subsample_size)
     for j in cols:
         for i in rows:
             site = "x_{}_{}".format(i, j)
             pyro.sample(site, dist.Bernoulli(p))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_valid_models.py


示例16: model

    def model(self):
        """
        Sample the ``u`` site and evaluate the conditional at ``self.X``.

        :return: ``(f_loc, f_var)`` when ``self.y`` is None; otherwise the
            result of ``self.likelihood(f_loc, f_var, self.y)`` evaluated
            under ``poutine.scale``.
        """
        self.set_mode("model")

        Xu = self.get_param("Xu")
        u_loc = self.get_param("u_loc")
        u_scale_tril = self.get_param("u_scale_tril")

        M = Xu.shape[0]
        # Jittered kernel matrix on Xu and its lower Cholesky factor.
        jitter_eye = torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Kuu = self.kernel(Xu) + jitter_eye
        Luu = Kuu.potrf(upper=False)

        zero_loc = Xu.new_zeros(u_loc.shape)
        u_name = param_with_module_name(self.name, "u")
        # The prior on u uses the identity as scale_tril when whitened and
        # Luu otherwise.
        if self.whiten:
            prior_tril = torch.eye(M, out=Xu.new_empty(M, M))
        else:
            prior_tril = Luu
        u_prior = dist.MultivariateNormal(zero_loc, scale_tril=prior_tril)
        pyro.sample(u_name, u_prior.independent(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten,
                                   jitter=self.jitter)
        f_loc = f_loc + self.mean_function(self.X)

        if self.y is None:
            return f_loc, f_var

        # Scale by num_data / number of rows in X.
        with poutine.scale(None, self.num_data / self.X.shape[0]):
            return self.likelihood(f_loc, f_var, self.y)
开发者ID:lewisKit,项目名称:pyro,代码行数:33,代码来源:vsgp.py


示例17: sample_ws

 def sample_ws(name, width):
     """Sample the Gamma site ``"w_<name>"``.

     ``self``, ``rand_tensor`` and ``Gamma`` come from an outer scope that
     is not part of this block.

     :param name: suffix used in the parameter and sample site names.
     :param width: shape argument handed to ``rand_tensor``.
     """
     def _init_alpha():
         return rand_tensor(width, self.alpha_init, self.sigma_init)

     def _init_mean():
         return rand_tensor(width, self.mean_init, self.sigma_init)

     raw_alpha = pyro.param("log_alpha_w_q_%s" % name, _init_alpha)
     raw_mean = pyro.param("log_mean_w_q_%s" % name, _init_mean)
     # Pass both raw parameters through softplus before use.
     alpha = self.softplus(raw_alpha)
     mean = self.softplus(raw_mean)
     pyro.sample("w_%s" % name, Gamma(alpha, alpha / mean))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:sparse_gamma_def.py


示例18: model

    def model(self):
        """
        Sample the ``f`` site and compute the location and variance of f.

        :return: ``(f_loc, f_var)`` when ``self.y`` is None; otherwise the
            result of ``self.likelihood(f_loc, f_var, self.y)``.
        """
        self.set_mode("model")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        N = self.X.shape[0]
        # Jittered kernel matrix on X and its lower Cholesky factor.
        jitter_eye = torch.eye(N, out=self.X.new_empty(N, N)) * self.jitter
        Kff = self.kernel(self.X) + jitter_eye
        Lff = Kff.potrf(upper=False)

        zero_loc = self.X.new_zeros(f_loc.shape)
        f_name = param_with_module_name(self.name, "f")

        # The prior on f uses the identity as scale_tril when whitened and
        # Lff otherwise.
        if self.whiten:
            prior_tril = torch.eye(N, out=self.X.new_empty(N, N))
        else:
            prior_tril = Lff
        f_prior = dist.MultivariateNormal(zero_loc, scale_tril=prior_tril)
        pyro.sample(f_name, f_prior.independent(zero_loc.dim() - 1))

        # When whitened, both the scale factor and the location are
        # multiplied by Lff.
        if self.whiten:
            f_scale_tril = Lff.matmul(f_scale_tril)
            f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1)

        f_var = f_scale_tril.pow(2).sum(dim=-1)
        f_loc = f_loc + self.mean_function(self.X)

        if self.y is None:
            return f_loc, f_var
        return self.likelihood(f_loc, f_var, self.y)
开发者ID:lewisKit,项目名称:pyro,代码行数:34,代码来源:vgp.py


示例19: model

def model(data):
    """Observe each row of ``data`` under a Poisson site.

    :param data: tensor whose first dimension is iterated; row ``celltype``
        is the observation for that iteration.
    """
    # Fix: this value was bound to ``idk`` but read as ``idka`` below, which
    # raised NameError on every call.
    idka = torch.tensor(4.0)  # where is my neural network?
    idkb = torch.tensor(4.0)
    # NOTE(review): ``genelambdas`` is a distribution object, not a sampled
    # value, yet it is passed to ``dist.Poisson`` as its argument below;
    # confirm whether a ``pyro.sample`` of this Gamma was intended.
    genelambdas = dist.Gamma(idka, idkb, batch_size=19795)
    for celltype in range(data.size(0)):  # this one's 56 right?
        with iarange('observe_{}'.format(celltype)):
            # NOTE(review): the site name 'indiv' is reused on every
            # iteration; confirm whether per-iteration names are required.
            pyro.sample('indiv', dist.Poisson(genelambdas), obs=data[celltype])
开发者ID:anihamde,项目名称:cs287-s18,代码行数:8,代码来源:vi_pyro.py


示例20: irange_model

def irange_model(subsample_size):
    """Sample one Normal site per index of a subsampled ``pyro.irange``.

    :param subsample_size: subsample size forwarded to ``pyro.irange``.
    :return: the visited indices, in iteration order.
    """
    size = 20
    loc, scale = torch.zeros(size), torch.ones(size)
    visited = []
    for idx in pyro.irange('irange', size, subsample_size):
        site_dist = dist.Normal(loc[idx], scale[idx])
        pyro.sample("x_{}".format(idx), site_dist)
        visited.append(idx)
    return visited
开发者ID:lewisKit,项目名称:pyro,代码行数:8,代码来源:test_mapdata.py



注:本文中的pyro.sample函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python poutine.trace函数代码示例发布时间:2022-05-27
下一篇:
Python pyro.param函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap