Python source code examples: torch.ifft()
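All of the examples below use the legacy FFT API from PyTorch 0.4–1.7, in which torch.ifft(input, signal_ndim, normalized=False) takes a real tensor whose last dimension of size 2 stacks the real and imaginary parts, and transforms the signal_ndim dimensions preceding it. This API was removed in PyTorch 1.8 in favor of the torch.fft module. As a minimal sketch of the correspondence (the two variants cannot run under the same PyTorch version, so treat this as illustrative):

import torch

x = torch.randn(4, 8, 8, 2)  # last dimension holds (real, imag)

# Legacy API (PyTorch < 1.8): 2-D complex-to-complex inverse FFT
y_old = torch.ifft(x, 2)

# Modern API (PyTorch >= 1.8): operate on native complex tensors
y_new = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(x)))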
Example 1
def ifft2(data):
"""
Apply centered 2-dimensional Inverse Fast Fourier Transform.
Args:
data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
-3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
assumed to be batch dimensions.
Returns:
torch.Tensor: The IFFT of the input.
"""
assert data.size(-1) == 2
data = ifftshift(data, dim=(-3, -2))
data = torch.ifft(data, 2, normalized=True)
data = fftshift(data, dim=(-3, -2))
return data
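This example assumes fftshift/ifftshift helpers that are not shown here. A minimal sketch of such helpers built on torch.roll, matching how they are called above (the exact originals may differ):

import torch

def fftshift(x, dim):
    # Move the zero-frequency component to the center of each given dim.
    return torch.roll(x, [x.size(d) // 2 for d in dim], dim)

def ifftshift(x, dim):
    # Inverse of fftshift; differs from it only for odd-sized dims.
    return torch.roll(x, [(x.size(d) + 1) // 2 for d in dim], dim)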
Example 2
def forward(self, x):
bsn = 1
batchSize, dim, h, w = x.data.shape
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # (batchSize, h, w, dim) -> (batchSize*h*w, dim)
y = torch.ones(batchSize, self.output_dim, device=x.device)
for img in range(batchSize // bsn):
segLen = bsn * h * w
upper = batchSize * h * w
interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
batch_x = x_flat[interLarge, :]
sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)
sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)
Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])
tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]
y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)
y = self._signed_sqrt(y)
y = self._l2norm(y)
return y
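The Re/Im lines implement elementwise complex multiplication in the (real, imag) last-dimension layout; taking the IFFT of the product of the two FFTs yields the circular convolution of the two count sketches. The same pattern can be factored into a small helper (a sketch; the name complex_mult_last_dim is ours, not from the original code):

import torch

def complex_mult_last_dim(a, b):
    # (a_re + i*a_im) * (b_re + i*b_im), with dim -1 holding (real, imag)
    re = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
    im = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    return torch.stack((re, im), dim=-1)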
Example 3
def ifft(t):
# Complex-to-complex Inverse Discrete Fourier Transform
return torch.ifft(t, 2)
Example 4
def pad_irfft3(F):
"""
padded batch inverse real fft
    :param F: tensor of shape [..., res0, res1, res2/2+1, 2]
"""
res = F.shape[-3]
f0 = torch.ifft(F.transpose(-4,-2), signal_ndim=1).transpose(-2,-4)
f1 = torch.ifft(f0.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
f2 = torch.irfft(f1, signal_ndim=1, signal_sizes=[res]) # [..., res0, res1, res2]
return f2
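Assuming PyTorch < 1.8 (where the legacy torch.rfft exists and, with the default onesided=True, produces exactly the [..., res0, res1, res2/2+1, 2] layout expected here), the function should invert an unnormalized forward real FFT. A hedged round-trip check:

import torch

f = torch.randn(2, 8, 8, 8)
F = torch.rfft(f, 3)  # shape (2, 8, 8, 5, 2)
assert torch.allclose(pad_irfft3(F), f, atol=1e-5)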
Example 5
def pad_ifft2(F):
    """
    padded batch 2-D inverse fft
    :param F: tensor of shape [..., res0, res1, 2]
    """
    f0 = torch.ifft(F.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    f1 = torch.ifft(f0, signal_ndim=1)
    return f1
Example 6
def so3_ifft(x, for_grad=False, b_out=None):
'''
:param x: [l * m * n, ..., complex]
'''
assert x.size(-1) == 2
nspec = x.size(0)
b_in = round((3 / 4 * nspec) ** (1 / 3))
assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
if b_out is None:
b_out = b_in
batch_size = x.size()[1:-1]
x = x.view(nspec, -1, 2) # [l * m * n, batch, complex] (nspec, nbatch, 2)
'''
:param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
:return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
'''
nbatch = x.size(1)
wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) # [beta, l * m * n] (2 * b_out, nspec)
output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
if x.is_cuda and x.dtype == torch.float32:
cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
cuda_kernel(x, wigner, output) # [batch, beta, m, n, complex]
else:
output.fill_(0)
for l in range(min(b_in, b_out)):
s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
l1 = min(l, b_out - 1) # if b_out < b_in
output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
if l > 0:
output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]
output = torch.ifft(output, 2) * output.size(-2) ** 2 # [batch, beta, alpha, gamma, complex]
output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2)
return output
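A note on the trailing scale factor: with the legacy default normalized=False, torch.ifft(output, 2) divides by the product of the two signal sizes, here (2 * b_out) ** 2, so multiplying by output.size(-2) ** 2 cancels that division and leaves the plain (unnormalized) Fourier synthesis sum over the m and n axes.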
Example 7
def backward(self, grad_output): # pylint: disable=W
# ifft of grad_output is not necessarily real, therefore we cannot use rifft
return so3_ifft(grad_output, for_grad=True, b_out=self.b_in)[..., 0], None
Example 8
def forward(self, bottom):
batch_size, _, height, width = bottom.size()
bottom_flat = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
sketch_1 = bottom_flat.mm(self.sparse_sketch_matrix1)
sketch_2 = bottom_flat.mm(self.sparse_sketch_matrix2)
im_zeros_1 = torch.zeros(sketch_1.size()).to(sketch_1.device)
im_zeros_2 = torch.zeros(sketch_2.size()).to(sketch_2.device)
fft1 = torch.fft(torch.cat([sketch_1.unsqueeze(-1), im_zeros_1.unsqueeze(-1)], dim=-1), 1)
fft2 = torch.fft(torch.cat([sketch_2.unsqueeze(-1), im_zeros_2.unsqueeze(-1)], dim=-1), 1)
fft_product_real = fft1[..., 0].mul(fft2[..., 0]) - fft1[..., 1].mul(fft2[..., 1])
fft_product_imag = fft1[..., 0].mul(fft2[..., 1]) + fft1[..., 1].mul(fft2[..., 0])
cbp_flat = torch.ifft(torch.cat([
fft_product_real.unsqueeze(-1),
fft_product_imag.unsqueeze(-1)],
dim=-1), 1)[..., 0]
cbp = cbp_flat.view(batch_size, height, width, self.output_dim)
if self.sum_pool:
cbp = cbp.sum(dim=[1, 2])
return cbp
Example 9
def _setup(self, config):
torch.manual_seed(config['seed'])
self.model = nn.Sequential(
BlockPermProduct(size=config['size'], complex=True, share_logit=False),
Block2x2DiagProduct(size=config['size'], complex=True)
)
self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
self.n_steps_per_epoch = config['n_steps_per_epoch']
size = config['size']
    self.target_matrix = torch.fft(real_to_complex(torch.eye(size)), 1)
    # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)), 1)
self.input = real_to_complex(torch.eye(size))
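The real_to_complex helper is not shown in this example; a minimal sketch consistent with how it is used (append a zero imaginary channel to obtain the legacy (real, imag) layout):

import torch

def real_to_complex(x):
    # real tensor [...] -> legacy complex layout [..., 2] with zero imag part
    return torch.stack((x, torch.zeros_like(x)), dim=-1)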
Example 10
def fft(input, inverse=False):
"""Interface with torch FFT routines for 3D signals.
fft of a 3d signal
Example
-------
x = torch.randn(128, 32, 32, 32, 2)
x_fft = fft(x)
x_ifft = fft(x, inverse=True)
Parameters
----------
    input : tensor
Complex input for the FFT.
inverse : bool
True for computing the inverse FFT.
Raises
------
TypeError
        In the event that input does not have a final dimension 2,
        i.e. is not complex.
Returns
-------
output : tensor
Result of FFT or IFFT.
"""
if not _is_complex(input):
raise TypeError('The input should be complex (e.g. last dimension is 2)')
if inverse:
return torch.ifft(input, 3)
return torch.fft(input, 3)
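The _is_complex predicate used above is not shown; under the legacy layout it usually just checks the trailing dimension. A one-line sketch:

def _is_complex(x):
    # A trailing dimension of size 2 is the legacy (real, imag) layout.
    return x.size(-1) == 2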
Example 11
def torch_ifft2(k, normalized=True):
""" ifft on last 2 dim """
kt = numpy_to_torch(k)
xt = torch.ifft(kt, 2, normalized)
return torch_to_complex_numpy(xt)
Example 12
def torch_ifft2c(x, normalized=True):
""" ifft2 on last 2 dim """
x = np.fft.ifftshift(x, axes=(-2,-1))
xt = numpy_to_torch(x)
    kt = torch.ifft(xt, 2, normalized=normalized)
k = torch_to_complex_numpy(kt)
return np.fft.fftshift(k, axes=(-2,-1))
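Both NumPy-interop examples above rely on numpy_to_torch and torch_to_complex_numpy helpers that are not shown. A minimal sketch consistent with their usage (complex ndarray <-> real tensor with a trailing (real, imag) dimension):

import numpy as np
import torch

def numpy_to_torch(a):
    # complex ndarray [...] -> real tensor [..., 2]
    return torch.from_numpy(np.stack((a.real, a.imag), axis=-1).astype(np.float32))

def torch_to_complex_numpy(t):
    # real tensor [..., 2] -> complex ndarray [...]
    a = t.numpy()
    return a[..., 0] + 1j * a[..., 1]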
Example 13
def ifft2(data):
assert data.size(-1) == 2
data = torch.ifft(data, 2, normalized=True)
return data
Example 14
def ifft2c(data):
"""
Apply centered 2-dimensional Inverse Fast Fourier Transform.
Args:
data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
-3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
assumed to be batch dimensions.
Returns:
torch.Tensor: The IFFT of the input.
"""
assert data.size(-1) == 2
data = ifftshift(data, dim=(-3, -2))
data = torch.ifft(data, 2, normalized=True)
data = fftshift(data, dim=(-3, -2))
return data
Example 15
def fft_adj(x, ndim=2):
return torch.ifft(x, signal_ndim=ndim, normalized=True)
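With normalized=True the legacy transforms are unitary, so this inverse FFT is also the adjoint of the normalized forward FFT. A quick sanity check (a sketch, assuming PyTorch < 1.8):

import torch

x = torch.randn(4, 8, 8, 2)
assert torch.allclose(fft_adj(torch.fft(x, 2, normalized=True)), x, atol=1e-6)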
Example 16
def fft_filter(x, kern, norm=None):
"""FFT-based filtering on a 2-size oversampled grid.
"""
im_size = torch.tensor(x.shape).to(torch.long)[3:]
grid_size = im_size * 2
# set up n-dimensional zero pad
pad_sizes = []
permute_dims = [0, 1]
inv_permute_dims = [0, 1, 2 + grid_size.shape[0]]
for i in range(grid_size.shape[0]):
pad_sizes.append(0)
pad_sizes.append(int(grid_size[-1 - i] - im_size[-1 - i]))
permute_dims.append(3 + i)
inv_permute_dims.append(2 + i)
permute_dims.append(2)
pad_sizes = tuple(pad_sizes)
permute_dims = tuple(permute_dims)
inv_permute_dims = tuple(inv_permute_dims)
# zero pad and fft
x = F.pad(x, pad_sizes)
x = x.permute(permute_dims)
x = torch.fft(x, grid_size.numel())
if norm == 'ortho':
x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
x = x.permute(inv_permute_dims)
# apply the filter
x = complex_mult(x, kern, dim=2)
# inverse fft
x = x.permute(permute_dims)
x = torch.ifft(x, grid_size.numel())
x = x.permute(inv_permute_dims)
# crop to input size
    crop_starts = tuple(np.array(x.shape).astype(int) * 0)
crop_ends = [x.shape[0], x.shape[1], x.shape[2]]
for dim in im_size:
crop_ends.append(int(dim))
x = x[tuple(map(slice, crop_starts, crop_ends))]
# scaling, assume user handled adjoint scaling with their kernel
if norm == 'ortho':
x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
return x
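fft_filter additionally assumes a complex_mult(x, kern, dim=2) helper (torchkbnufft ships one with this signature), where dim names the axis stacking (real, imag). A sketch under that assumption:

import torch

def complex_mult(a, b, dim=0):
    # Elementwise complex product; `dim` indexes the (real, imag) axis.
    a_re, a_im = a.select(dim, 0), a.select(dim, 1)
    b_re, b_im = b.select(dim, 0), b.select(dim, 1)
    return torch.stack((a_re * b_re - a_im * b_im,
                        a_re * b_im + a_im * b_re), dim=dim)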
Example 17
def s2_ifft(x, for_grad=False, b_out=None):
'''
:param x: [l * m, ..., complex]
'''
assert x.size(-1) == 2
nspec = x.size(0)
b_in = round(nspec ** 0.5)
assert nspec == b_in ** 2
if b_out is None:
b_out = b_in
assert b_out >= b_in
batch_size = x.size()[1:-1]
x = x.view(nspec, -1, 2) # [l * m, batch, complex] (nspec, nbatch, 2)
'''
:param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
:return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
'''
nbatch = x.size(1)
wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
wigner = wigner.view(2 * b_out, -1) # [beta, l * m] (2 * b_out, nspec)
if x.is_cuda and x.dtype == torch.float32:
import s2cnn.utils.cuda as cuda_utils
cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
cuda_kernel(block=(1024, 1, 1),
grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
stream=stream)
# [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
else:
output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
for l in range(b_in):
s = slice(l ** 2, l ** 2 + 2 * l + 1)
out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
output[:, :, :l + 1] += out[:, :, -l - 1:]
if l > 0:
output[:, :, -l:] += out[:, :, :l]
output = torch.ifft(output, 1) * output.size(-2) # [batch, beta, alpha, complex]
output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
return output
Example 18
def so3_rifft(x, for_grad=False, b_out=None):
'''
:param x: [l * m * n, ..., complex]
'''
assert x.size(-1) == 2
nspec = x.size(0)
b_in = round((3 / 4 * nspec) ** (1 / 3))
assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
if b_out is None:
b_out = b_in
batch_size = x.size()[1:-1]
x = x.view(nspec, -1, 2) # [l * m * n, batch, complex] (nspec, nbatch, 2)
'''
:param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
:return: [batch, beta, alpha, gamma] (nbatch, 2 b_out, 2 b_out, 2 b_out)
'''
nbatch = x.size(1)
wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device) # [beta, l * m * n] (2 * b_out, nspec)
output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
if x.is_cuda and x.dtype == torch.float32:
cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=True, device=x.device.index)
cuda_kernel(x, wigner, output) # [batch, beta, m, n, complex]
else:
# TODO can be optimized knowing that the output is real, like in _setup_so3ifft_cuda_kernel(real_output=True)
output.fill_(0)
for l in range(min(b_in, b_out)):
s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
l1 = min(l, b_out - 1) # if b_out < b_in
output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
if l > 0:
output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]
output = torch.ifft(output, 2) * output.size(-2) ** 2 # [batch, beta, alpha, gamma, complex]
output = output[..., 0] # [batch, beta, alpha, gamma]
output = output.contiguous()
output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out)
return output