Python source code examples: torch.diag_embed()
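The examples below are collected from open-source projects. As a quick reference before diving in: torch.diag_embed(input, offset=0, dim1=-2, dim2=-1) returns a tensor whose chosen diagonals (by default the last two dimensions) are filled with the values of input, so a tensor of shape (*, N) becomes a batch of N x N diagonal matrices of shape (*, N, N). The snippet below is a minimal illustration of this behavior; it is not taken from any of the quoted projects.

import torch

# A (2, 3) tensor of per-sample standard deviations ...
scale = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
# ... becomes a (2, 3, 3) batch of diagonal matrices.
cov = torch.diag_embed(scale)
print(cov[0])
# tensor([[1., 0., 0.],
#         [0., 2., 0.],
#         [0., 0., 3.]])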
Example 1
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    if 'logits' in ActionPD.arg_constraints:  # discrete
        action_pd = ActionPD(logits=pdparam)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0, use softplus with positive
        scale = F.softplus(scale) + 1e-8
        if isinstance(pdparam, list):  # split output
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd
Example 2
def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'):
    super(RGCN, self).__init__()
    self.device = device
    # adj_norm = normalize(adj)
    # first turn original features to distribution
    self.lr = lr
    self.gamma = gamma
    self.beta1 = beta1
    self.beta2 = beta2
    self.nclass = nclass
    self.nhid = nhid // 2
    # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout)
    # self.gc2 = GaussianConvolution(nhid, nclass, dropout)
    self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout)
    self.gc2 = GGCL_D(nhid, nclass, dropout=dropout)
    self.dropout = dropout
    # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass))
    self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass),
                                       torch.diag_embed(torch.ones(nnodes, self.nclass)))
    self.adj_norm1, self.adj_norm2 = None, None
    self.features, self.labels = None, None
Example 3
def from_log_cholesky(cls,
                      log_diag: torch.Tensor,
                      off_diag: torch.Tensor,
                      **kwargs) -> 'Covariance':
    assert log_diag.shape[:-1] == off_diag.shape[:-1]

    batch_dim = log_diag.shape[:-1]
    rank = log_diag.shape[-1]
    L = torch.diag_embed(torch.exp(log_diag))

    idx = 0
    for i in range(rank):
        for j in range(i):
            L[..., i, j] = off_diag[..., idx]
            idx += 1

    out = cls(size=batch_dim + (rank, rank))
    if kwargs:
        out = out.to(**kwargs)
    perm_shape = tuple(range(len(batch_dim))) + (-1, -2)
    out[:] = L.matmul(L.permute(perm_shape))
    return out
Example 4
def init_action_pd(ActionPD, pdparam):
    '''
    Initialize the action_pd for discrete or continuous actions:
    - discrete: action_pd = ActionPD(logits)
    - continuous: action_pd = ActionPD(loc, scale)
    '''
    args = ActionPD.arg_constraints
    if 'logits' in args:  # discrete
        # for relaxed discrete dist. with reparametrizable discrete actions
        pd_kwargs = {'temperature': torch.tensor(1.0)} if hasattr(ActionPD, 'temperature') else {}
        action_pd = ActionPD(logits=pdparam, **pd_kwargs)
    else:  # continuous, args = loc and scale
        if isinstance(pdparam, list):  # split output
            loc, scale = pdparam
        else:
            loc, scale = pdparam.transpose(0, 1)
        # scale (stdev) must be > 0, log-clamp-exp
        scale = torch.clamp(scale, min=-20, max=2).exp()
        if 'covariance_matrix' in args:  # split output
            # construct covars from a batched scale tensor
            covars = torch.diag_embed(scale)
            action_pd = ActionPD(loc=loc, covariance_matrix=covars)
        else:
            action_pd = ActionPD(loc=loc, scale=scale)
    return action_pd
Example 5
def __local_curvatures(self, module, g_inp, g_out):
    if self.derivatives.hessian_is_zero():
        return []
    if not self.derivatives.hessian_is_diagonal():
        raise NotImplementedError

    def positive_part(sign, H):
        return clamp(sign * H, min=0)

    def diag_embed_multi_dim(H):
        """Convert [N, C_in, H_in, ...] to [N, C_in * H_in * ... = V],
        embed into [N, V, V], then convert back to [V, N, C_in, H_in, ...]."""
        feature_shapes = H.shape[1:]
        V, N = prod(feature_shapes), H.shape[0]

        H_diag = diag_embed(H.view(N, V))
        # [V, N, C_in, H_in, ...]
        shape = (V, N, *feature_shapes)
        return einsum("nic->cni", H_diag).view(shape)

    def decompose_into_positive_and_negative_sqrt(H):
        return [
            [diag_embed_multi_dim(positive_part(sign, H).sqrt_()), sign]
            for sign in [self.PLUS, self.MINUS]
        ]

    H = self.derivatives.hessian_diagonal(module, g_inp, g_out)
    return decompose_into_positive_and_negative_sqrt(H)
Example 6
def _sqrt_hessian(self, module, g_inp, g_out):
    self._check_2nd_order_parameters(module)

    probs = self._get_probs(module)
    tau = torchsqrt(probs)
    V_dim, C_dim = 0, 2
    Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
    Id_tautau = Id - einsum("nv,nc->vnc", tau, tau)
    sqrt_H = einsum("nc,vnc->vnc", tau, Id_tautau)

    if module.reduction == "mean":
        N = module.input0.shape[0]
        sqrt_H /= sqrt(N)

    return sqrt_H
Example 7
def get_laplacian_nuc_norm(self, A: 'N x C x S'):
    N, C, _ = A.size()
    # print(A)
    AAT = torch.bmm(A, A.permute(0, 2, 1))
    ones = torch.ones((N, C, 1), device='cuda')
    D = torch.bmm(AAT, ones).view(N, C)
    D = torch.diag_embed(D)
    return nuclear_norm(D - AAT, sym=True).sum() / N
Example 8
def evaluate(self, state, action):
    action_mean = self.actor(state)

    action_var = self.action_var.expand_as(action_mean)
    cov_mat = torch.diag_embed(action_var).to(device)

    dist = MultivariateNormal(action_mean, cov_mat)

    action_logprobs = dist.log_prob(action)
    dist_entropy = dist.entropy()
    state_value = self.critic(state)

    return action_logprobs, torch.squeeze(state_value), dist_entropy
Example 9
def evaluate_lazy_tensor(self, lazy_tensor):
    diag = lazy_tensor._diag_tensor._diag
    tensor = lazy_tensor._lazy_tensor.tensor
    return tensor + torch.diag_embed(diag, dim1=-2, dim2=-1)
Example 10
def evaluate_lazy_tensor(self, lazy_tensor):
    diag = lazy_tensor._diag_tensor._diag
    tensor = lazy_tensor._lazy_tensor.tensor
    return tensor + torch.diag_embed(diag, dim1=-2, dim2=-1)
Example 11
def evaluate(self):
    if self._diag.dim() == 0:
        return self._diag
    return torch.diag_embed(self._diag)
Example 12
def _eval_corr_matrix(self):
    tnc = self.task_noise_corr
    fac_diag = torch.ones(*tnc.shape[:-1], self.num_tasks, device=tnc.device, dtype=tnc.dtype)
    Cfac = torch.diag_embed(fac_diag)
    Cfac[..., self.tidcs[0], self.tidcs[1]] = self.task_noise_corr
    # squared rows must sum to one for this to be a correlation matrix
    C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
    return C @ C.transpose(-1, -2)
Example 13
def _create_marginal_input(self, batch_shape=torch.Size()):
    mat = torch.randn(*batch_shape, 5, 5)
    eye = torch.diag_embed(torch.ones(*batch_shape, 5))
    return MultivariateNormal(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye)
Example 14
def matrix(self):
    """Matrix form of the butterfly matrix
    """
    if not self.complex:
        return (torch.diag(self.diag)
                + torch.diag(self.subdiag, -self.diagonal)
                + torch.diag(self.superdiag, self.diagonal))
    else:  # Use torch.diag_embed (available in Pytorch 1.0) to deal with complex case.
        return (torch.diag_embed(self.diag.t(), dim1=0, dim2=1)
                + torch.diag_embed(self.subdiag.t(), -self.diagonal, dim1=0, dim2=1)
                + torch.diag_embed(self.superdiag.t(), self.diagonal, dim1=0, dim2=1))
Example 15
def _get_test_posterior(shape, device, dtype, interleaved=True, lazy=False):
    mean = torch.rand(shape, device=device, dtype=dtype)
    n_covar = shape[-2:].numel()
    diag = torch.rand(shape, device=device, dtype=dtype)
    diag = diag.view(*diag.shape[:-2], n_covar)
    a = torch.rand(*shape[:-2], n_covar, n_covar, device=device, dtype=dtype)
    covar = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
    if lazy:
        covar = NonLazyTensor(covar)
    if shape[-1] == 1:
        mvn = MultivariateNormal(mean.squeeze(-1), covar)
    else:
        mvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mvn)
Example 16
def test_lognorm_to_norm(self):
    for dtype in (torch.float, torch.double):
        # independent case
        mu = torch.tensor([0.25, 0.5, 1.0], device=self.device, dtype=dtype)
        diag = torch.tensor([0.5, 2.0, 1.0], device=self.device, dtype=dtype)
        Cov = torch.diag_embed((math.exp(1) - 1) * diag)
        mu_n, Cov_n = lognorm_to_norm(mu, Cov)
        mu_n_expected = torch.tensor(
            [-2.73179, -2.03864, -0.5], device=self.device, dtype=dtype
        )
        diag_expected = torch.tensor(
            [2.69099, 2.69099, 1.0], device=self.device, dtype=dtype
        )
        self.assertTrue(torch.allclose(mu_n, mu_n_expected))
        self.assertTrue(torch.allclose(Cov_n, torch.diag_embed(diag_expected)))
        # correlated case
        Z = torch.zeros(3, 3, device=self.device, dtype=dtype)
        Z[0, 2] = math.sqrt(math.exp(1)) - 1
        Z[2, 0] = math.sqrt(math.exp(1)) - 1
        mu = torch.ones(3, device=self.device, dtype=dtype)
        Cov = torch.diag_embed(mu * (math.exp(1) - 1)) + Z
        mu_n, Cov_n = lognorm_to_norm(mu, Cov)
        mu_n_expected = -0.5 * torch.ones(3, device=self.device, dtype=dtype)
        Cov_n_expected = torch.tensor(
            [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.5, 0.0, 1.0]],
            device=self.device,
            dtype=dtype,
        )
        self.assertTrue(torch.allclose(mu_n, mu_n_expected, atol=1e-4))
        self.assertTrue(torch.allclose(Cov_n, Cov_n_expected, atol=1e-4))
Example 17
def test_norm_to_lognorm(self):
    for dtype in (torch.float, torch.double):
        # Test joint, independent
        expmu = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
        expdiag = torch.tensor([1.5, 2.0, 3], device=self.device, dtype=dtype)
        mu = torch.log(expmu)
        diag = torch.log(expdiag)
        Cov = torch.diag_embed(diag)
        mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
        mu_ln_expected = expmu * torch.exp(0.5 * diag)
        diag_ln_expected = torch.tensor(
            [0.75, 8.0, 54.0], device=self.device, dtype=dtype
        )
        Cov_ln_expected = torch.diag_embed(diag_ln_expected)
        self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
        # Test joint, correlated
        Cov[0, 2] = 0.1
        Cov[2, 0] = 0.1
        mu_ln, Cov_ln = norm_to_lognorm(mu, Cov)
        Cov_ln_expected[0, 2] = 0.669304
        Cov_ln_expected[2, 0] = 0.669304
        self.assertTrue(torch.allclose(Cov_ln, Cov_ln_expected))
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
        # Test marginal
        mu = torch.tensor([-1.0, 0.0, 1.0], device=self.device, dtype=dtype)
        v = torch.tensor([1.0, 2.0, 3.0], device=self.device, dtype=dtype)
        var = 2 * (torch.log(v) - mu)
        mu_ln = norm_to_lognorm_mean(mu, var)
        var_ln = norm_to_lognorm_variance(mu, var)
        mu_ln_expected = torch.tensor(
            [1.0, 2.0, 3.0], device=self.device, dtype=dtype
        )
        var_ln_expected = (torch.exp(var) - 1) * mu_ln_expected ** 2
        self.assertTrue(torch.allclose(mu_ln, mu_ln_expected))
        self.assertTrue(torch.allclose(var_ln, var_ln_expected))
Example 18
def test_round_trip(self):
    for dtype in (torch.float, torch.double):
        for batch_shape in ([], [2]):
            mu = 5 + torch.rand(*batch_shape, 4, device=self.device, dtype=dtype)
            a = 0.2 * torch.randn(
                *batch_shape, 4, 4, device=self.device, dtype=dtype
            )
            diag = 3.0 + 2 * torch.rand(
                *batch_shape, 4, device=self.device, dtype=dtype
            )
            Cov = a @ a.transpose(-1, -2) + torch.diag_embed(diag)
            mu_n, Cov_n = lognorm_to_norm(mu, Cov)
            mu_rt, Cov_rt = norm_to_lognorm(mu_n, Cov_n)
            self.assertTrue(torch.allclose(mu_rt, mu, atol=1e-4))
            self.assertTrue(torch.allclose(Cov_rt, Cov, atol=1e-4))
Example 19
def tobit_adjustment(mean: Tensor,
                     cov: Tensor,
                     lower: Optional[Tensor] = None,
                     upper: Optional[Tensor] = None,
                     probs: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
    assert cov.shape[-1] == cov.shape[-2]  # symmetrical

    if upper is None:
        upper = torch.full_like(mean, float('inf'))
    if lower is None:
        lower = torch.full_like(mean, -float('inf'))

    assert lower.shape == upper.shape == mean.shape

    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)

    if not is_cens_up.any() and not is_cens_lo.any():
        return mean, cov

    F1, F2 = _F1F2(mean, cov, lower, upper)

    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
    sqrt_pi = pi ** .5

    # prob censoring:
    if probs is None:
        prob_lo, prob_up = tobit_probs(mean=mean,
                                       cov=cov,
                                       lower=lower,
                                       upper=upper)
    else:
        prob_lo, prob_up = probs

    # adjust mean:
    lower_adj = torch.zeros_like(mean)
    lower_adj[is_cens_lo] = prob_lo[is_cens_lo] * lower[is_cens_lo]
    upper_adj = torch.zeros_like(mean)
    upper_adj[is_cens_up] = prob_up[is_cens_up] * upper[is_cens_up]
    mean_if_uncens = mean + (sqrt(2. / pi) * F1) * std
    mean_uncens_adj = (1. - prob_up - prob_lo) * mean_if_uncens
    mean_adj = mean_uncens_adj + upper_adj + lower_adj

    # adjust cov:
    diag_adj = torch.zeros_like(mean)
    for m in range(mean.shape[-1]):
        diag_adj[..., m] = (1. + 2. / sqrt_pi * F2[..., m] - 2. / pi * (F1[..., m] ** 2)) * cov[..., m, m]

    cov_adj = torch.diag_embed(diag_adj)

    return mean_adj, cov_adj
Example 20
def _update_group(self,
                  obs: Tensor,
                  group_idx: Union[slice, Sequence[int]],
                  which_valid: Union[slice, Sequence[int]],
                  lower: Optional[Tensor] = None,
                  upper: Optional[Tensor] = None
                  ) -> Tuple[Tensor, Tensor]:
    # indices:
    idx_2d = bmat_idx(group_idx, which_valid)
    idx_3d = bmat_idx(group_idx, which_valid, which_valid)

    # observed values, censoring limits
    obs = obs[idx_2d]
    if lower is None:
        lower = torch.full_like(obs, -float('inf'))
    else:
        lower = lower[idx_2d]
        if torch.isnan(lower).any():
            raise ValueError("NaNs not allowed in `lower`")
    if upper is None:
        upper = torch.full_like(obs, float('inf'))
    else:
        upper = upper[idx_2d]
        if torch.isnan(upper).any():
            raise ValueError("NaNs not allowed in `upper`")
    if (lower == upper).any():
        raise RuntimeError("lower cannot == upper")

    # subset belief / design-mats:
    means = self.means[group_idx]
    covs = self.covs[group_idx]
    R = self.R[idx_3d]
    H = self.H[idx_2d]
    measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

    # calculate censoring fx:
    prob_lo, prob_up = tobit_probs(mean=measured_means,
                                   cov=R,
                                   lower=lower,
                                   upper=upper)
    prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

    mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                     cov=R,
                                     lower=lower,
                                     upper=upper,
                                     probs=(prob_lo, prob_up))

    # kalman gain:
    K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)

    # update
    means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
    covs_new = self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs)
    return means_new, covs_new
Example 21
def test_transformed_posterior(self):
    for dtype in (torch.float, torch.double):
        for m in (1, 2):
            shape = torch.Size([3, m])
            mean = torch.rand(shape, dtype=dtype, device=self.device)
            variance = 1 + torch.rand(shape, dtype=dtype, device=self.device)
            if m == 1:
                covar = torch.diag_embed(variance.squeeze(-1))
                mvn = MultivariateNormal(mean.squeeze(-1), lazify(covar))
            else:
                covar = torch.diag_embed(variance.view(*variance.shape[:-2], -1))
                mvn = MultitaskMultivariateNormal(mean, lazify(covar))
            p_base = GPyTorchPosterior(mvn=mvn)
            p_tf = TransformedPosterior(  # dummy transforms
                posterior=p_base,
                sample_transform=lambda s: s + 2,
                mean_transform=lambda m, v: 2 * m + v,
                variance_transform=lambda m, v: m + 2 * v,
            )
            # mean, variance
            self.assertEqual(p_tf.device.type, self.device.type)
            self.assertTrue(p_tf.dtype == dtype)
            self.assertEqual(p_tf.event_shape, shape)
            self.assertTrue(torch.equal(p_tf.mean, 2 * mean + variance))
            self.assertTrue(torch.equal(p_tf.variance, mean + 2 * variance))
            # rsample
            samples = p_tf.rsample()
            self.assertEqual(samples.shape, torch.Size([1]) + shape)
            samples = p_tf.rsample(sample_shape=torch.Size([4]))
            self.assertEqual(samples.shape, torch.Size([4]) + shape)
            samples2 = p_tf.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2]) + shape)
            # rsample w/ base samples
            base_samples = torch.randn(4, *shape, device=self.device, dtype=dtype)
            # incompatible shapes
            with self.assertRaises(RuntimeError):
                p_tf.rsample(
                    sample_shape=torch.Size([3]), base_samples=base_samples
                )
            # make sure sample transform is applied correctly
            samples_base = p_base.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_tf = p_tf.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.equal(samples_tf, samples_base + 2))
            # check error handling
            p_tf_2 = TransformedPosterior(
                posterior=p_base, sample_transform=lambda s: s + 2
            )
            with self.assertRaises(NotImplementedError):
                p_tf_2.mean
            with self.assertRaises(NotImplementedError):
                p_tf_2.variance
Example 22
def test_GPyTorchPosterior_Multitask(self):
    for dtype in (torch.float, torch.double):
        mean = torch.rand(3, 2, dtype=dtype, device=self.device)
        variance = 1 + torch.rand(3, 2, dtype=dtype, device=self.device)
        covar = variance.view(-1).diag()
        mvn = MultitaskMultivariateNormal(mean, lazify(covar))
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, self.device.type)
        self.assertTrue(posterior.dtype == dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 2]))
        self.assertTrue(torch.equal(posterior.mean, mean))
        self.assertTrue(torch.equal(posterior.variance, variance))
        # rsample
        samples = posterior.rsample(sample_shape=torch.Size([4]))
        self.assertEqual(samples.shape, torch.Size([4, 3, 2]))
        samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 2]))
        # rsample w/ base samples
        base_samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
        samples_b1 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        samples_b2 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        self.assertTrue(torch.allclose(samples_b1, samples_b2))
        base_samples2 = torch.randn(4, 2, 3, 2, device=self.device, dtype=dtype)
        samples2_b1 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        samples2_b2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
        # collapse_batch_dims
        b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
        b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
        b_covar = torch.diag_embed(b_variance.view(2, 6))
        b_mvn = MultitaskMultivariateNormal(b_mean, lazify(b_covar))
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
        b_samples = b_posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=b_base_samples
        )
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
Example 23
def _get_test_posterior(
    batch_shape: torch.Size,
    q: int = 1,
    m: int = 1,
    interleaved: bool = True,
    lazy: bool = False,
    independent: bool = False,
    **tkwargs
) -> GPyTorchPosterior:
    r"""Generate a Posterior for testing purposes.

    Args:
        batch_shape: The batch shape of the data.
        q: The number of candidates.
        m: The number of outputs.
        interleaved: A boolean indicating the format of the
            MultitaskMultivariateNormal.
        lazy: A boolean indicating if the posterior should be lazy.
        independent: A boolean indicating whether the outputs are independent.
        tkwargs: `device` and `dtype` tensor constructor kwargs.
    """
    if independent:
        mvns = []
        for _ in range(m):
            mean = torch.rand(*batch_shape, q, **tkwargs)
            a = torch.rand(*batch_shape, q, q, **tkwargs)
            covar = a @ a.transpose(-1, -2)
            flat_diag = torch.rand(*batch_shape, q, **tkwargs)
            covar = covar + torch.diag_embed(flat_diag)
            mvns.append(MultivariateNormal(mean, covar))
        mtmvn = MultitaskMultivariateNormal.from_independent_mvns(mvns)
    else:
        mean = torch.rand(*batch_shape, q, m, **tkwargs)
        a = torch.rand(*batch_shape, q * m, q * m, **tkwargs)
        covar = a @ a.transpose(-1, -2)
        flat_diag = torch.rand(*batch_shape, q * m, **tkwargs)
        if lazy:
            covar = AddedDiagLazyTensor(covar, DiagLazyTensor(flat_diag))
        else:
            covar = covar + torch.diag_embed(flat_diag)
        mtmvn = MultitaskMultivariateNormal(mean, covar, interleaved=interleaved)
    return GPyTorchPosterior(mtmvn)