Python源码示例:torch.pinverse()
示例1
def get_inverse_filters(self):
    """Return the pseudo-inverse of the FFT basis wrapped as an ``nn.Parameter``.

    The basis returned by ``self._get_fft_basis()`` is given a leading axis of
    size 1 before ``torch.pinverse`` is applied, and that axis is squeezed away
    again afterwards. The parameter's ``requires_grad`` flag is taken from
    ``self.requires_grad``.
    """
    basis = self._get_fft_basis()
    batched_pinv = torch.pinverse(basis.unsqueeze(0))
    inverse = batched_pinv.squeeze(0)
    return nn.Parameter(inverse, requires_grad=self.requires_grad)
示例2
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
    """Compute a clamped importance weight and weighted reward for one sample.

    The weight is ``tgt^T @ pinv(log @ log^T) @ ones``, where ``log`` and
    ``tgt`` are the flattened slot/item expectation vectors of the logging
    and target policies (length ``len(slots) * len(items)``). ``ones`` is the
    one-hot encoding of the logged slate. The scalar is passed through
    ``self._weight_clamper``.

    Returns:
        ``EstimatorSampleResult(log_reward, log_reward * weight,
        ground_truth_reward, weight)``. Returns ``None``, after logging a
        warning, when either slot distribution is unavailable.
    """
    log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
    if log_slot_expects is None:
        logger.warning("Log slot distribution not available")
        return None
    tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
    if tgt_slot_expects is None:
        logger.warning("Target slot distribution not available")
        return None
    log_indicator = log_slot_expects.values_tensor(self._device)
    tgt_indicator = tgt_slot_expects.values_tensor(self._device)
    lm = len(sample.context.slots) * len(sample.items)
    outer = torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))
    # torch.pinverse is not very stable here, so the pseudo-inverse is taken
    # with NumPy. Tensor.numpy() only works on CPU tensors that do not require
    # grad, hence detach().cpu(). The result is moved back to self._device so
    # the matmul with `ones` (created on self._device) does not mix devices.
    gamma = torch.as_tensor(
        np.linalg.pinv(outer.detach().cpu().numpy()), device=self._device
    )
    ones = sample.log_slate.one_hots(sample.items, self._device)
    weight = self._weight_clamper(
        torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
    ).item()
    return EstimatorSampleResult(
        sample.log_reward,
        sample.log_reward * weight,
        sample.ground_truth_reward,
        weight,
    )
# pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently.
示例3
def fit(self):
if self.readout_training in {'gd', 'svd'}:
return
if self.readout_training == 'cholesky':
W = torch.solve(self.XTy,
self.XTX + self.lambda_reg * torch.eye(
self.XTX.size(0), device=self.XTX.device))[0].t()
self.XTX = None
self.XTy = None
self.readout.bias = nn.Parameter(W[:, 0])
self.readout.weight = nn.Parameter(W[:, 1:])
elif self.readout_training == 'inv':
I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
self.XTX.device)
A = self.XTX + I
if torch.det(A) != 0:
W = torch.mm(torch.inverse(A), self.XTy).t()
else:
pinv = torch.pinverse(A)
W = torch.mm(pinv, self.XTy).t()
self.readout.bias = nn.Parameter(W[:, 0])
self.readout.weight = nn.Parameter(W[:, 1:])
self.XTX = None
self.XTy = None
示例4
def compute_filter_pinv(self, filters):
    """Return the pseudo-inverse filterbank of ``filters``, in the input's shape.

    All singleton dimensions of ``filters`` are squeezed away before
    ``torch.pinverse`` is applied. NOTE(review): this presumably assumes a
    shape like (n_filters, 1, kernel_size); confirm against callers. The
    transposed pseudo-inverse is viewed back to the original shape and
    multiplied by ``stride / kernel_size`` of ``self.filterbank``.
    """
    overlap_gain = self.filterbank.stride / self.filterbank.kernel_size
    original_shape = filters.shape
    matrix = filters.squeeze()
    inverse = torch.pinverse(matrix).transpose(-1, -2)
    # Compensate for the overlap-add.
    return inverse.view(original_shape) * overlap_gain