Python源码示例:torch.pinverse()

示例1
def get_inverse_filters(self):
        """Build the pseudo-inverse of the Fourier basis as a module parameter.

        The basis returned by ``self._get_fft_basis()`` is given a leading
        dimension of size one before ``torch.pinverse`` is applied, and that
        leading dimension is removed again from the result.

        Returns:
            ``nn.Parameter`` holding the pseudo-inverse; whether it tracks
            gradients follows ``self.requires_grad``.
        """
        basis = self._get_fft_basis()
        pinv_batched = torch.pinverse(basis.unsqueeze(0))
        return nn.Parameter(
            pinv_batched.squeeze(0),
            requires_grad=self.requires_grad,
        )
示例2
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
        """Estimate one logged sample using a pseudo-inverse slate weight.

        Builds the outer product of the logged slot/item indicator vector with
        itself, takes its Moore-Penrose pseudo-inverse ``gamma`` and computes
        the weight ``tgt_indicator^T @ gamma @ ones``, where ``ones`` comes from
        ``sample.log_slate.one_hots``. The weight is passed through
        ``self._weight_clamper`` and converted to a Python scalar.

        Args:
            sample: the logged sample to evaluate.

        Returns:
            An ``EstimatorSampleResult`` built from the logged reward, the
            weighted logged reward, the ground-truth reward and the weight, or
            ``None`` (after logging a warning) when either the logging or the
            target slot distribution is unavailable.
        """
        log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
        if log_slot_expects is None:
            logger.warning("Log slot distribution not available")
            return None
        tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
        if tgt_slot_expects is None:
            logger.warning("Target slot distribution not available")
            return None
        log_indicator = log_slot_expects.values_tensor(self._device)
        tgt_indicator = tgt_slot_expects.values_tensor(self._device)
        lm = len(sample.context.slots) * len(sample.items)
        outer = torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))
        # torch.pinverse was found not to be very stable here, so the
        # pseudo-inverse is computed with NumPy. NumPy works on host memory
        # only: a bare ``.numpy()`` raises for tensors on a non-CPU device, so
        # detach and copy to the CPU first, then put gamma back on the device
        # the other operands live on (otherwise torch.mm below mixes devices).
        gamma = torch.as_tensor(
            np.linalg.pinv(outer.detach().cpu().numpy()),
            device=outer.device,
        )
        ones = sample.log_slate.one_hots(sample.items, self._device)
        weight = self._weight_clamper(
            torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
        ).item()
        return EstimatorSampleResult(
            sample.log_reward,
            sample.log_reward * weight,
            sample.ground_truth_reward,
            weight,
        )

    # pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently. 
示例3
def fit(self):
        """Solve for the readout from the accumulated normal-equation statistics.

        For ``readout_training`` in ``{'gd', 'svd'}`` this is a no-op, as it is
        for any value other than ``'cholesky'`` and ``'inv'``. For those two
        modes the regularized system ``(XTX + lambda_reg * I) @ S = XTy`` is
        solved for ``S``. Then ``W = S.t()`` is split into ``readout.bias``
        (column 0) and ``readout.weight`` (remaining columns), and ``XTX`` /
        ``XTy`` are released by setting them to ``None``.

        ``'cholesky'`` lets a solver error for a singular system propagate;
        ``'inv'`` falls back to the Moore-Penrose pseudo-inverse instead.

        NOTE(review): taking column 0 as the bias presumably means the design
        matrix X carries a leading constant column -- confirm with the code
        that accumulates XTX/XTy.
        """
        mode = self.readout_training
        if mode not in {'cholesky', 'inv'}:
            # 'gd' / 'svd' train the readout elsewhere; other modes are ignored.
            return

        XTX, XTy = self.XTX, self.XTy
        # Build the regularizer with XTX's dtype/device so lambda_reg is not
        # first rounded to the default dtype and no host->device copy is needed.
        ridge = self.lambda_reg * torch.eye(
            XTX.size(0), dtype=XTX.dtype, device=XTX.device)
        A = XTX + ridge

        if mode == 'cholesky':
            # torch.solve(B, A) no longer exists in PyTorch; torch.linalg.solve
            # takes (A, B) and solves A @ X = B.
            solution = torch.linalg.solve(A, XTy)
        else:
            # Attempt a direct solve rather than testing det(A) != 0: the
            # determinant can underflow to 0 for well-conditioned matrices (or
            # be tiny but nonzero for near-singular ones), and solving is more
            # accurate than forming inverse(A) explicitly.
            try:
                solution = torch.linalg.solve(A, XTy)
            except RuntimeError:
                # Singular system (torch.linalg.LinAlgError is a RuntimeError).
                solution = torch.mm(torch.pinverse(A), XTy)

        W = solution.t()
        self.readout.bias = nn.Parameter(W[:, 0])
        self.readout.weight = nn.Parameter(W[:, 1:])

        # Free the accumulated statistics once the readout is fitted.
        self.XTX = None
        self.XTy = None
示例4
def compute_filter_pinv(self, filters):
        """Compute the pseudo-inverse filterbank of the given filters.

        All size-1 dimensions of ``filters`` are squeezed away, the
        Moore-Penrose pseudo-inverse is taken over the last two dimensions,
        those two dimensions are swapped back and the result is viewed with the
        original shape of ``filters``. Finally it is scaled by
        ``filterbank.stride / filterbank.kernel_size`` to compensate for the
        overlap-add.

        NOTE(review): ``squeeze()`` drops *every* size-1 dimension, so a single
        filter (or length-1 kernels) collapses to 1-D and ``torch.pinverse``
        rejects it -- confirm callers never pass such a filterbank.

        Args:
            filters: filter tensor; presumably ``(n_filters, 1, kernel_size)``
                -- TODO confirm.

        Returns:
            Tensor with the same shape as ``filters``.
        """
        # Overlap-add compensation factor.
        overlap_scale = self.filterbank.stride / self.filterbank.kernel_size
        original_shape = filters.shape
        pinv = torch.pinverse(filters.squeeze())
        inverse_filters = pinv.transpose(-1, -2).view(original_shape)
        return inverse_filters * overlap_scale