Python source code examples: torch.slogdet()
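torch.slogdet(A) returns a pair (sign, logabsdet): the sign of the determinant of a square matrix (or batch of matrices) and the natural log of its absolute value. It is the numerically safer way to work with log-determinants, which is why the normalizing-flow code below uses it. A minimal standalone sketch (my own illustration, not taken from any of the examples; the examples themselves assume the usual imports such as torch, torch.nn.functional as F, and typing.Tuple already being in scope):

import torch

A = torch.tensor([[2.0, 0.0],
                  [0.0, -3.0]])       # det = -6
sign, logabsdet = torch.slogdet(A)
print(sign)                            # tensor(-1.)
print(logabsdet)                       # tensor(1.7918) == log(6)
det = sign * torch.exp(logabsdet)      # recovers -6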
Example 1
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`
    """
    dim = input.dim()
    # [batch, N1, N2, ..., in_features]
    out = F.linear(input, self.weight)
    _, logdet = torch.slogdet(self.weight)
    if dim > 2:
        num = mask.view(out.size(0), -1).sum(dim=1)
        logdet = logdet * num
    return out, logdet
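The multiplication by num works because F.linear applies the same square weight independently at every masked position, so the Jacobian of the whole flow is block-diagonal with one copy of the weight per position and the per-position log-determinants add. A standalone check of that per-position fact (my own sketch, not part of the repository the example comes from):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(4, 4)
x = torch.randn(4)

# For a single feature vector, the Jacobian of F.linear(v, W) is W (up to
# transpose), so its log|det| equals torch.slogdet(W)[1]; with `num`
# independent positions these values simply add up.
jac = torch.autograd.functional.jacobian(lambda v: F.linear(v, W), x)
print(torch.allclose(torch.slogdet(jac)[1], torch.slogdet(W)[1]))  # True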
Example 2
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`
    """
    dim = input.dim()
    # [batch, N1, N2, ..., in_features]
    out = F.linear(input, self.weight_inv)
    _, logdet = torch.slogdet(self.weight_inv)
    if dim > 2:
        num = mask.view(out.size(0), -1).sum(dim=1)
        logdet = logdet * num
    return out, logdet
Example 3
def logabsdet(x):
    """Returns the log absolute determinant of square matrix x."""
    # Note: torch.logdet() only works for positive determinant.
    _, res = torch.slogdet(x)
    return res
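A short usage sketch (not from the source) showing why the helper relies on torch.slogdet(): for a matrix with negative determinant, torch.logdet() produces nan, while slogdet still reports the log of the absolute determinant together with the sign.

import torch

m = torch.tensor([[0.0, 1.0],
                  [1.0, 0.0]])   # det = -1
print(torch.logdet(m))           # tensor(nan)
print(torch.slogdet(m))          # sign = -1, logabsdet = 0
print(logabsdet(m))              # tensor(0.), assuming the helper above is in scope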
Example 4
def forward(self, inputs, cond_inputs=None, mode='direct'):
    if mode == 'direct':
        return inputs @ self.W, torch.slogdet(
            self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                inputs.size(0), 1)
    else:
        return inputs @ torch.inverse(self.W), -torch.slogdet(
            self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                inputs.size(0), 1)
Example 5
def forward(self, x):
    # x --> z
    # torch.slogdet() is not stable
    if self.train_sampling:
        W = torch.inverse(self.weight.double()).float()
    else:
        W = self.weight
    logdet = self.log_determinant(x, W)
    kernel = W.view(*self.w_shape, 1, 1)
    return F.conv2d(x, kernel), logdet
Example 6
def backward(self, y: torch.Tensor, x: torch.Tensor = None, x_freqs: torch.Tensor = None, require_log_probs=True, var=None, y_freqs=None):
    # map from the other language to this language
    x_prime = y.mm(self.W)
    if require_log_probs:
        # both x and x_freqs are needed to evaluate the log-probability
        assert x is not None and x_freqs is not None
        log_probs = self.cal_mixture_of_gaussian_fix_var(x_prime, x, x_freqs, var, x_prime_freqs=y_freqs)
        _, log_abs_det = torch.slogdet(self.W)
        log_probs = log_probs + log_abs_det
    else:
        log_probs = torch.tensor(0)
    return x_prime, log_probs
Example 7
def forward(self, x, sldj, reverse=False):
    ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
    if reverse:
        weight = torch.inverse(self.weight.double()).float()
        sldj = sldj - ldj
    else:
        weight = self.weight
        sldj = sldj + ldj
    weight = weight.view(self.num_channels, self.num_channels, 1, 1)
    z = F.conv2d(x, weight)
    return z, sldj
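The factor x.size(2) * x.size(3) appears because the 1x1 convolution applies the same channel-mixing matrix at every spatial location, so the full Jacobian is block-diagonal with H * W copies of the weight. A small standalone check of that reading (my own sketch, with a toy 3-channel input; not part of the example's repository):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, H, W = 3, 2, 2
weight = torch.randn(C, C)
x = torch.randn(1, C, H, W)

# Jacobian of the 1x1 convolution, flattened to a (C*H*W, C*H*W) matrix.
f = lambda v: F.conv2d(v, weight.view(C, C, 1, 1))
jac = torch.autograd.functional.jacobian(f, x).reshape(C * H * W, C * H * W)

# Its log|det| matches H * W * slogdet(weight)[1].
print(torch.allclose(torch.slogdet(jac)[1],
                     H * W * torch.slogdet(weight)[1], atol=1e-5))  # True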
Example 8
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`
    """
    size = input.size()
    dim = input.dim()
    # [batch, N1, N2, ..., heads, in_features / heads]
    if self.type == 'A':
        out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
    else:
        out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)
    out = F.linear(out, self.weight)
    if self.type == 'B':
        out = out.transpose(-2, -1).contiguous()
    out = out.view(*size)
    _, logdet = torch.slogdet(self.weight)
    if dim > 2:
        num = mask.view(size[0], -1).sum(dim=1) * self.heads
        logdet = logdet * num
    return out, logdet
Example 9
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`
    """
    size = input.size()
    dim = input.dim()
    # [batch, N1, N2, ..., heads, in_features / heads]
    if self.type == 'A':
        out = input.view(*size[:-1], self.heads, self.in_features // self.heads)
    else:
        out = input.view(*size[:-1], self.in_features // self.heads, self.heads).transpose(-2, -1)
    out = F.linear(out, self.weight_inv)
    if self.type == 'B':
        out = out.transpose(-2, -1).contiguous()
    out = out.view(*size)
    _, logdet = torch.slogdet(self.weight_inv)
    if dim > 2:
        num = mask.view(size[0], -1).sum(dim=1) * self.heads
        logdet = logdet * num
    return out, logdet
Example 10
def get_weight(self, input, reverse):
    w_shape = self.w_shape
    if not self.LU:
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if not reverse:
            weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
        else:
            weight = torch.inverse(self.weight.double()).float() \
                .view(w_shape[0], w_shape[1], 1, 1)
        return weight, dlogdet
    else:
        self.p = self.p.to(input.device)
        self.sign_s = self.sign_s.to(input.device)
        self.l_mask = self.l_mask.to(input.device)
        self.eye = self.eye.to(input.device)
        l = self.l * self.l_mask + self.eye
        u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        dlogdet = thops.sum(self.log_s) * thops.pixels(input)
        if not reverse:
            w = torch.matmul(self.p, torch.matmul(l, u))
        else:
            l = torch.inverse(l.double()).float()
            u = torch.inverse(u.double()).float()
            w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
        return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet
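In the LU branch no call to torch.slogdet() is needed: p is a permutation (|det| = 1), l has a unit diagonal, and the diagonal of u is sign_s * exp(log_s), so log|det(p l u)| reduces to the sum of log_s. A quick numerical sanity check of that identity (my own sketch, not part of the repository):

import torch

torch.manual_seed(0)
n = 4
p = torch.eye(n)[torch.randperm(n)]                     # permutation matrix
l = torch.tril(torch.randn(n, n), -1) + torch.eye(n)    # unit lower-triangular
log_s = torch.randn(n)
sign_s = torch.sign(torch.randn(n))
u = torch.triu(torch.randn(n, n), 1) + torch.diag(sign_s * torch.exp(log_s))

w = p @ l @ u
print(torch.allclose(torch.slogdet(w)[1], log_s.sum(), atol=1e-5))  # True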
Example 11
def get_parameters(self, x, inverse):
    w_shape = self.w_shape
    pixels = np.prod(x.size()[2:])
    device = x.device
    if not self.decomposed:
        logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(device) * pixels
        if not inverse:
            weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
        else:
            weight = torch.inverse(self.weight.double()).float().view(w_shape[0], w_shape[1], 1, 1)
        return weight, logdet_jacobian
    else:
        self.p = self.p.to(device)
        self.sign_s = self.sign_s.to(device)
        self.l_mask = self.l_mask.to(device)
        self.eye = self.eye.to(device)
        l = self.l * self.l_mask + self.eye
        u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
        logdet_jacobian = torch.sum(self.log_s) * pixels
        if not inverse:
            w = torch.matmul(self.p, torch.matmul(l, u))
        else:
            l = torch.inverse(l.double()).float()
            u = torch.inverse(u.double()).float()
            w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
        return w.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian
Example 12
def forward(self, input):
    _, _, height, width = input.shape
    out = F.conv2d(input, self.weight)
    logdet = (
        height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
    )
    return out, logdet