Python source code examples: torch.fft()

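All of the snippets below use the legacy function-style FFT API (torch.fft, torch.ifft, torch.rfft, torch.irfft with a signal_ndim argument), where complex values live in a real tensor whose last dimension has size 2. That API was deprecated in PyTorch 1.7 and later removed in favor of the torch.fft module. A rough correspondence, as a minimal sketch assuming PyTorch >= 1.8 (normalization and layout must be matched to the specific call being replaced; normalized=True in the old calls corresponds to norm="ortho"):

import torch

x = torch.randn(4, 8, 8)                 # real-valued batch

# legacy torch.rfft(x, signal_ndim=2) roughly corresponds to:
X = torch.fft.rfftn(x, dim=(-2, -1))     # complex tensor, last dim is 8 // 2 + 1

# legacy torch.fft(x_c, signal_ndim=1) on an old-style [..., 2] real/imag tensor
# roughly corresponds to a complex FFT over the last signal dimension:
x_c = torch.randn(4, 8, 2)
X_c = torch.fft.fft(torch.view_as_complex(x_c), dim=-1)  # last dim must be size 2 and contiguous
back = torch.view_as_real(X_c)           # back to the [..., 2] layout
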
Example 1
def pad_rfft3(f, onesided=True):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1, res2]
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = int(n0/2), int(n1/2), int(n2/2)

    F2 = torch.rfft(f, signal_ndim=1, onesided=onesided) # [..., res0, res1, res2/2+1, 2]
    F2[..., h2, :] = 0

    F1 = torch.fft(F2.transpose(-3,-2), signal_ndim=1)
    F1[..., h1,:] = 0
    F1 = F1.transpose(-2,-3)

    F0 = torch.fft(F1.transpose(-4,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-4)
    return F0 
Example 2
def pad_fft2(f):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1]
    """
    n0, n1 = f.shape[-2:]
    h0, h1 = int(n0/2), int(n1/2)
    # turn f into complex signal
    f = torch.stack((f, torch.zeros_like(f)), dim=-1) # [..., res0, res1, 2]

    F1 = torch.fft(f, signal_ndim=1) # [..., res0, res1, 2]
    F1[..., h1,:] = 0 # [..., res0, res1, 2]

    F0 = torch.fft(F1.transpose(-3,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-3)
    return F0 
Example 3
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors
    :param res: n_dims int tuple of number of frequency modes
    :return: frequency tensor of shape [dim, res, res, res/2+1]
    """
    # print("res",res)
    n_dims = len(res)
    freqs = []
    for dim in range(n_dims - 1):
        r_ = res[dim]
        freq = np.fft.fftfreq(r_, d=1/r_)
        freqs.append(torch.tensor(freq, dtype=dtype))
    r_ = res[-1]
    if exact:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
    else:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
    omega = torch.meshgrid(freqs)
    omega = list(omega)
    omega = torch.stack(omega, dim=0)

    # print("omega.shape",omega.shape)
    return omega 
Example 4
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.

    Returns:
        torch.Tensor: The FFT of the input.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    data = torch.fft(data, 2, normalized=True)
    data = fftshift(data, dim=(-3, -2))
    return data 
Example 5
def test_butterfly_fft():
    # DFT matrix for n = 4
    size = 4
    DFT = torch.fft(real_to_complex(torch.eye(size)), 1)
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.],
                                      [0., 0., 1., 0.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D) 
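The test above checks that two Butterfly factors composed with a permutation reproduce the size-4 DFT matrix (torch.fft applied to the identity). The underlying identity is the radix-2 Cooley-Tukey factorization. A minimal standalone check of that structure in plain numpy, assuming nothing from the project above (butterfly, block and perm are illustrative names; the Butterfly modules realize the same structure with a different parameterization):

import numpy as np

n = 4
DFT4 = np.fft.fft(np.eye(n))                 # target matrix, as in the test above
W = np.exp(-2j * np.pi / n)                  # twiddle factor
butterfly = np.array([[1, 0,  1,  0],        # X[k]     = E[k] + W**k * O[k]
                      [0, 1,  0,  W],
                      [1, 0, -1,  0],        # X[k + 2] = E[k] - W**k * O[k]
                      [0, 1,  0, -W]])
block = np.kron(np.eye(2), np.fft.fft(np.eye(2)))   # two independent DFT_2 blocks
perm = np.eye(n)[:, [0, 2, 1, 3]]            # even-indexed inputs first, then odd
assert np.allclose(butterfly @ block @ perm, DFT4)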
Example 6
def forward(self, x):
         bsn = 1
         batchSize, dim, h, w = x.data.shape
         x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # batchsize,h, w, dim,
         y = torch.ones(batchSize, self.output_dim, device=x.device)

         for img in range(batchSize // bsn):
             segLen = bsn * h * w
             upper = batchSize * h * w
             interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
             interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
             batch_x = x_flat[interLarge, :]

             sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
             sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)

             sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
             sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)

             Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
             Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

             tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

             y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

         y = self._signed_sqrt(y)
         y = self._l2norm(y)
         return y 
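The Re/Im lines above perform an explicit complex multiplication in the frequency domain, so the final torch.ifft returns the circular convolution of the two count-sketch vectors (the Compact Bilinear Pooling trick). A minimal check of that identity with the modern complex API, assuming PyTorch >= 1.8 (a, b and the result names are illustrative):

import torch

n = 16
a, b = torch.randn(n), torch.randn(n)

# circular convolution: direct[k] = sum_j a[j] * b[(k - j) % n]
direct = torch.stack([sum(a[j] * b[(k - j) % n] for j in range(n)) for k in range(n)])
via_fft = torch.fft.ifft(torch.fft.fft(a) * torch.fft.fft(b)).real
assert torch.allclose(direct, via_fft, atol=1e-4)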
Example 7
def get_uperleft_denominator(img, kernel):
    '''
    img: HxWxC
    kernel: hxw
    denominator: HxWx1
    upperleft: HxWxC
    '''
    V = psf2otf(kernel, img.shape[:2])
    denominator = np.expand_dims(np.abs(V)**2, axis=2)
    upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
    return upperleft, denominator 
Example 8
def fft(t):
    return torch.fft(t, 2) 
Example 9
def otf2psf(otf, outsize=None):
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    for axis, axis_size in enumerate(insize):
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if outsize is not None:
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")

        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf


# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py 
Example 10
def fft(t):
    # Complex-to-complex Discrete Fourier Transform
    return torch.fft(t, 2) 
Example 11
def forward(self, head, rel, tail):
        h_e, r_e, t_e = self.embed(head, rel, tail)
        r_e = F.normalize(r_e, p=2, dim=-1)
        h_e = torch.stack((h_e, torch.zeros_like(h_e)), -1)
        t_e = torch.stack((t_e, torch.zeros_like(t_e)), -1)
        e, _ = torch.unbind(torch.ifft(torch.conj(torch.fft(h_e, 1)) * torch.fft(t_e, 1), 1), -1)
        return -F.sigmoid(torch.sum(r_e * e, 1)) 
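The score above takes conj(fft(h_e)) * fft(t_e) and keeps the real part of the inverse FFT, which corresponds to the circular correlation of the head and tail embeddings (the HolE-style interaction). A minimal check of that identity with the modern complex API, assuming PyTorch >= 1.8 (h, t and the result names are illustrative):

import torch

n = 8
h, t = torch.randn(n), torch.randn(n)

# circular correlation: direct[k] = sum_j h[j] * t[(j + k) % n]
direct = torch.stack([sum(h[j] * t[(j + k) % n] for j in range(n)) for k in range(n)])
via_fft = torch.fft.ifft(torch.conj(torch.fft.fft(h)) * torch.fft.fft(t)).real
assert torch.allclose(direct, via_fft, atol=1e-4)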
Example 12
def pad_irfft3(F):
    """
    padded batch inverse real fft
    :param F: tensor of shape [..., res0, res1, res2/2+1, 2]
    """
    res = F.shape[-3]
    f0 = torch.ifft(F.transpose(-4,-2), signal_ndim=1).transpose(-2,-4)
    f1 = torch.ifft(f0.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    f2 = torch.irfft(f1, signal_ndim=1, signal_sizes=[res]) # [..., res0, res1, res2]
    return f2 
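A minimal usage sketch for the pad_rfft3 / pad_irfft3 pair (shapes only), assuming PyTorch <= 1.7 (where torch.rfft / torch.irfft still exist) and pad_rfft3 from Example 1 in scope; the round trip is not exact because the forward pass zeroes the Nyquist modes:

import torch

f = torch.randn(2, 8, 8, 8)      # [batch, res0, res1, res2]
F = pad_rfft3(f)                 # [2, 8, 8, 8 // 2 + 1, 2]
f_rec = pad_irfft3(F)            # [2, 8, 8, 8]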
Example 13
def pad_ifft2(F):
    """
    padded batch inverse fft (counterpart of pad_fft2)
    :param F: tensor of shape [..., res0, res1, 2]
    """
    f0 = torch.ifft(F.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    f1 = torch.ifft(f0, signal_ndim=1)
    return f1 
Example 14
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors
    """
    if dim is None:
        dim = tuple(range(x.dim()))
        shift = [dim // 2 for dim in x.shape]
    elif isinstance(dim, int):
        shift = x.shape[dim] // 2
    else:
        shift = [x.shape[i] // 2 for i in dim]
    return roll(x, shift, dim) 
Example 15
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors
    """
    if dim is None:
        dim = tuple(range(x.dim()))
        shift = [(dim + 1) // 2 for dim in x.shape]
    elif isinstance(dim, int):
        shift = (x.shape[dim] + 1) // 2
    else:
        shift = [(x.shape[i] + 1) // 2 for i in dim]
    return roll(x, shift, dim) 
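In PyTorch >= 1.8 these helpers exist directly as torch.fft.fftshift and torch.fft.ifftshift, with the same semantics as the numpy versions. A minimal equivalence check, assuming that PyTorch version:

import torch

x = torch.randn(3, 6, 5)
# fftshift rolls each selected dimension forward by size // 2
assert torch.equal(torch.fft.fftshift(x, dim=(-2, -1)),
                   torch.roll(x, shifts=(6 // 2, 5 // 2), dims=(-2, -1)))
# ifftshift undoes fftshift exactly, also for odd-length dimensions
assert torch.equal(torch.fft.ifftshift(torch.fft.fftshift(x)), x)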
Example 16
def fft2_np(data):
    """
    Numpy version of fft2
    """
    data = np.fft.ifftshift(data, axes=(-2, -1))
    data = np.fft.fft2(data, norm="ortho")
    data = np.fft.fftshift(data, axes=(-2, -1))
    return data 
Example 17
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors
    """
    if dim is None:
        dim = tuple(range(x.dim()))
        shift = [(dim + 1) // 2 for dim in x.shape] #TODO: looks wrong
    elif isinstance(dim, int):
        shift = (x.shape[dim] + 1) // 2
    else:
        shift = [(x.shape[i] + 1) // 2 for i in dim]
    return roll(x, shift, dim) 
Example 18
def forward(self, bottom):

        batch_size, _, height, width = bottom.size()

        bottom_flat = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)

        sketch_1 = bottom_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = bottom_flat.mm(self.sparse_sketch_matrix2)

        im_zeros_1 = torch.zeros(sketch_1.size()).to(sketch_1.device)
        im_zeros_2 = torch.zeros(sketch_2.size()).to(sketch_2.device)
        fft1 = torch.fft(torch.cat([sketch_1.unsqueeze(-1), im_zeros_1.unsqueeze(-1)], dim=-1), 1)
        fft2 = torch.fft(torch.cat([sketch_2.unsqueeze(-1), im_zeros_2.unsqueeze(-1)], dim=-1), 1)

        fft_product_real = fft1[..., 0].mul(fft2[..., 0]) - fft1[..., 1].mul(fft2[..., 1])
        fft_product_imag = fft1[..., 0].mul(fft2[..., 1]) + fft1[..., 1].mul(fft2[..., 0])

        cbp_flat = torch.ifft(torch.cat([
            fft_product_real.unsqueeze(-1),
            fft_product_imag.unsqueeze(-1)],
            dim=-1), 1)[..., 0]

        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=[1, 2])

        return cbp 
Example 19
def fft_test():
    # DFT matrix for n = 4
    size = 4
    DFT = torch.fft(real_to_complex(torch.eye(size)), 1)
    P = torch.stack((torch.tensor([[1., 0., 0., 0.],
                                   [0., 0., 1., 0.],
                                   [0., 1., 0., 0.],
                                   [0., 0., 0., 1.]]),
                     torch.zeros((size, size))), dim=-1)
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D) 
Example 20
def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size)) 
Example 21
def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
        self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size)) 
Example 22
def _setup(self, config):
        size = config['size']
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=size, complex=True, fixed_order=False, softmax_fn='sparsemax')
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size)) 
Example 23
def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'])
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.br_perm = torch.tensor(bitreversal_permutation(size))
        # br_perm = bitreversal_permutation(size)
        # br_reverse = torch.tensor(list(br_perm[::-1]))
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))
        # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
        # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
        # self.br_perm = br_reverse
        # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1])  # Doesn't work
        # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3])  # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7])  # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
        # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6])  # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
        # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7])  # Doesn't work
        # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3])  # Doesn't work
        # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3])  # This works, since it's just swapping 0 and 1
        # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3])  # This works, since it's swapping 4 and 5 
Example 24
def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = nn.Sequential(
            BlockPermProduct(size=config['size'], complex=True, share_logit=False),
            Block2x2DiagProduct(size=config['size'], complex=True)
        )
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)))
        self.input = real_to_complex(torch.eye(size)) 
Example 25
def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = nn.Sequential(
            Block2x2DiagProduct(size=config['size'], complex=True, decreasing_size=False),
            BlockPermProduct(size=config['size'], complex=True, share_logit=False, increasing_size=True),
        )
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
        self.input = real_to_complex(torch.eye(size)) 
Example 26
def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = ButterflyProduct(size=config['size'],
                                      complex=True,
                                      fixed_order=config['fixed_order'],
                                      softmax_fn=config['softmax_fn'],
                                      learn_perm=True)
        if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
            self.semantic_loss_weight = config['semantic_loss_weight']
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1) 
Example 27
def test_block2x2diagproduct():
    # Factorization of the DFT matrix
    size = 4
    model = Block2x2DiagProduct(size, complex=True)
    model.factors[1].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0]], [[1.0, 0.0]]], [[[1.0, 0.0]], [[-1.0, 0.0]]]]))
    model.factors[0].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[1.0, 0.0],
                                                         [0.0, -1.0]]],
                                                       [[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[-1.0, 0.0],
                                                         [0.0, 1.0]]]]))
    input = torch.stack((torch.eye(size), torch.zeros(size, size)), dim=-1)
    assert torch.allclose(model(input[:, [0, 2, 1, 3]]), torch.fft(input, 1))