Python源码示例:torch.fft()
示例1
def pad_rfft3(f, onesided=True):
    """
    Padded batch real 3-D FFT.

    A real FFT is taken along the last axis, then complex FFTs along the two
    preceding axes. After each transform the slice at index ``n // 2`` of that
    axis is set to zero.

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: if True, keep only res2 // 2 + 1 frequencies on the last
        axis; otherwise keep all res2 of them.
    :return: real tensor of shape [..., res0, res1, res2 // 2 + 1, 2]
        (or [..., res0, res1, res2, 2] when ``onesided`` is False). The
        trailing dimension holds (real, imag).
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2
    # torch.rfft / torch.fft(x, signal_ndim) were removed in PyTorch 1.8, so
    # the torch.fft module is used on complex tensors instead. Each FFT acts
    # independently on every index of the other axes, so a slice zeroed
    # earlier stays zero through the later transforms.
    spec = torch.fft.rfft(f, dim=-1) if onesided else torch.fft.fft(f, dim=-1)
    spec[..., h2] = 0
    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h1, :] = 0
    spec = torch.fft.fft(spec, dim=-3)
    spec[..., h0, :, :] = 0
    return torch.view_as_real(spec)
示例2
def pad_fft2(f):
    """
    Padded batch 2-D FFT of a real signal.

    A full complex FFT is taken along the last axis, then along the preceding
    axis. After each transform the slice at index ``n // 2`` of that axis is
    set to zero.

    :param f: real tensor of shape [..., res0, res1]
    :return: real tensor of shape [..., res0, res1, 2] holding (real, imag)
    """
    n0, n1 = f.shape[-2:]
    h0, h1 = n0 // 2, n1 // 2
    # torch.fft(x, signal_ndim) was removed in PyTorch 1.8. torch.fft.fft
    # accepts real input directly, so no explicit zero imaginary part is needed.
    spec = torch.fft.fft(f, dim=-1)
    spec[..., h1] = 0
    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h0, :] = 0
    return torch.view_as_real(spec)
示例3
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Build the integer frequency grid matching a real FFT of shape ``res``.

    Every axis except the last uses two-sided frequencies (``np.fft.fftfreq``).
    The last axis uses one-sided frequencies (``np.fft.rfftfreq``). The sample
    spacing is 1/r, so the values are integer mode numbers.

    :param res: sequence of ints, number of frequency modes per axis
    :param dtype: torch dtype of the returned tensor
    :param exact: if False, drop the final (largest) entry of the one-sided axis
    :return: tensor of shape [len(res), res[0], ..., res[-2], m], where m is
        res[-1] // 2 + 1 when ``exact`` is True and res[-1] // 2 otherwise
    """
    leading = res[:-1]
    last = res[-1]
    axes = [torch.tensor(np.fft.fftfreq(r, d=1 / r), dtype=dtype) for r in leading]
    half = np.fft.rfftfreq(last, d=1 / last)
    if not exact:
        half = half[:-1]
    axes.append(torch.tensor(half, dtype=dtype))
    grids = torch.meshgrid(axes)
    return torch.stack(list(grids), dim=0)
示例4
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    The spatial axes are ifft-shifted, then transformed with an orthonormal
    2-D FFT, then fft-shifted back, so the zero frequency ends up centred.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2 (real, imaginary). All other
            dimensions are assumed to be batch dimensions.
    Returns:
        torch.Tensor: The FFT of the input, in the same [..., H, W, 2] layout.
    """
    assert data.size(-1) == 2
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8. Work on a
    # complex view: real dims (-3, -2) become complex dims (-2, -1), and
    # norm="ortho" matches normalized=True.
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)
示例5
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    The spatial axes are ifft-shifted, then transformed with an orthonormal
    2-D FFT, then fft-shifted back, so the zero frequency ends up centred.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2 (real, imaginary). All other
            dimensions are assumed to be batch dimensions.
    Returns:
        torch.Tensor: The FFT of the input, in the same [..., H, W, 2] layout.
    """
    assert data.size(-1) == 2
    # torch.fft(data, 2, normalized=True) was removed in PyTorch 1.8. Work on a
    # complex view: real dims (-3, -2) become complex dims (-2, -1), and
    # norm="ortho" matches normalized=True.
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)
示例6
def test_butterfly_fft():
    """Check that two complex Butterfly factors multiply out to the size-4 DFT matrix.

    Complex matrices use a trailing dimension of size 2 holding (real, imag).
    P is the permutation matrix that swaps rows 1 and 2, i.e. the bit-reversal
    order [0, 2, 1, 3] for n = 4.
    """
    # DFT matrix for n = 4
    size = 4
    # torch.fft(x, 1) was removed in PyTorch 1.8. Take the DFT of the identity
    # with the torch.fft module, then convert back to the [..., 2] layout.
    eye = real_to_complex(torch.eye(size))
    DFT = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.],
                                      [0., 0., 1., 0.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    # M0 @ M1 @ P reproduces the DFT.
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    # Equivalently, permute the columns of M0 @ M1 by the bit-reversal order.
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
示例7
def forward(self, x):
    """
    Pool a feature map with two sketch projections combined by circular convolution.

    Steps:
    1. Project each spatial feature vector with the two matrices in ``self.sparseM``.
    2. Circularly convolve the two projections (a product in the Fourier domain).
    3. Sum the result over all spatial positions of each image.
    4. Pass it through ``self._signed_sqrt`` and ``self._l2norm``.

    :param x: tensor of shape [batch, dim, h, w]
    :return: tensor of shape [batch, self.output_dim]
    """
    bsn = 1  # images per chunk; bounds the size of the intermediate FFTs
    batch_size, dim, h, w = x.shape
    positions = h * w
    # [batch * h * w, dim]: one row per spatial position, grouped by image.
    x_flat = x.permute(0, 2, 3, 1).reshape(-1, dim)
    # Loop-invariant: move the projection matrices to the device once.
    proj1 = self.sparseM[0].to(x.device)
    proj2 = self.sparseM[1].to(x.device)
    y = torch.ones(batch_size, self.output_dim, device=x.device)
    # Stepping by bsn up to batch_size also covers a final partial chunk.
    # The old `batchSize // bsn` dropped that chunk, and the old
    # `interSmall` bound used batch*h*w instead of batch.
    for start in range(0, batch_size, bsn):
        stop = min(batch_size, start + bsn)
        batch_x = x_flat[start * positions:stop * positions]
        # torch.fft / torch.ifft were removed in PyTorch 1.8. With the
        # torch.fft module, complex multiplication is just `*`.
        spec1 = torch.fft.fft(batch_x.mm(proj1), dim=-1)
        spec2 = torch.fft.fft(batch_x.mm(proj2), dim=-1)
        tmp_y = torch.fft.ifft(spec1 * spec2, dim=-1).real
        y[start:stop] = tmp_y.reshape(stop - start, positions, self.output_dim).sum(dim=1)
    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
示例8
def get_uperleft_denominator(img, kernel):
    '''
    Compute the Fourier-domain terms conj(V) * FFT2(img) and |V| ** 2,
    where V = psf2otf(kernel, (H, W)).

    img: HxWxC
    kernel: hxw
    Returns (upperleft, denominator):
        upperleft: HxWxC, conj(V) times the 2-D FFT of img over axes 0 and 1
        denominator: HxWx1, squared magnitude of V
    '''
    otf = psf2otf(kernel, img.shape[:2])
    magnitude_sq = np.abs(otf) ** 2
    denominator = np.expand_dims(magnitude_sq, axis=2)
    img_freq = np.fft.fft2(img, axes=[0, 1])
    conj_otf = np.expand_dims(np.conj(otf), axis=2)
    upperleft = conj_otf * img_freq
    return upperleft, denominator
示例9
def fft(t):
    """
    Complex-to-complex 2-D DFT over the two dimensions before the last one.

    :param t: real tensor of shape [..., H, W, 2] holding (real, imag) pairs
    :return: unnormalized DFT, same shape and layout as ``t``
    """
    # torch.fft(t, 2) was removed in PyTorch 1.8; this is the torch.fft-module equivalent.
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
示例10
def otf2psf(otf, outsize=None):
    """
    Convert an optical transfer function (OTF) to a point-spread function (PSF).

    Presumably the counterpart of ``psf2otf`` (named after MATLAB's otf2psf);
    confirm against that helper. The visible steps are:
    1. Inverse FFT over axes 0 and 1.
    2. Circular shift of every axis by floor(size / 2).
    3. Optional crop to ``outsize``.
    4. Dropping the imaginary part if it is negligible.

    :param otf: numpy array; the inverse FFT is taken over its first two axes
    :param outsize: optional target size (sequence of ints); None skips the crop
    :return: numpy array, real if the imaginary part is within tolerance,
        otherwise complex
    """
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    # Circularly shift every axis by half its size.
    # NOTE(review): this also shifts axes beyond 0 and 1 (e.g. a channel axis),
    # which were not inverse-transformed above; confirm this is intended for 3-D input.
    for axis, axis_size in enumerate(insize):
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if type(outsize) != type(None):
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        # Turn both sizes into (k, 1) column vectors, then pad them to n rows.
        # NOTE(review): np.pad with mode="constant" pads with 0, not 1 as in
        # the MATLAB postpad lines above.
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")
        # Amount to crop per axis, split between the start (floor) and the end (ceil).
        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            # NOTE(review): only prints; execution continues into the crop
            # below with a negative pad.
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        # NOTE(review): dims_start has shape (n, 1), so len(dims_start.shape) == 2
        # and only axes 0 and 1 are cropped, whatever n is.
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    # Tolerance passed to np.real_if_close (measured in machine epsilons):
    # otf.size * sum(log2(shape)), roughly an FFT operation count.
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf
# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
示例11
def fft(t):
    """
    Complex-to-complex 2-D DFT over the two dimensions before the last one.

    :param t: real tensor of shape [..., H, W, 2] holding (real, imag) pairs
    :return: unnormalized DFT, same shape and layout as ``t``
    """
    # torch.fft(t, 2) was removed in PyTorch 1.8; this is the torch.fft-module equivalent.
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
示例12
def fft(t):
    """
    Complex-to-complex 2-D Discrete Fourier Transform.

    :param t: real tensor of shape [..., H, W, 2] holding (real, imag) pairs
    :return: unnormalized DFT over (H, W), same shape and layout as ``t``
    """
    # torch.fft(t, 2) was removed in PyTorch 1.8; this is the torch.fft-module equivalent.
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
示例13
def forward(self, head, rel, tail):
    """
    Score triples using the circular correlation of head and tail embeddings.

    The correlation e[k] = sum_i h[i] * t[(i + k) mod d] is computed as
    ifft(conj(fft(h)) * fft(t)). It is dotted with the L2-normalised relation
    embedding along dim 1.

    :return: ``-sigmoid(score)`` per triple
    """
    h_e, r_e, t_e = self.embed(head, rel, tail)
    r_e = F.normalize(r_e, p=2, dim=-1)
    # Bug fix: the old code stored complex numbers as real tensors with a
    # trailing dim of 2. On those, torch.conj is a no-op and `*` multiplies
    # real and imaginary parts independently, so no correlation was computed.
    # Genuine complex tensors (torch.fft module) give the intended result.
    h_f = torch.fft.fft(h_e, dim=-1)
    t_f = torch.fft.fft(t_e, dim=-1)
    e = torch.fft.ifft(torch.conj(h_f) * t_f, dim=-1).real
    # torch.sigmoid replaces the deprecated F.sigmoid (same values).
    return -torch.sigmoid(torch.sum(r_e * e, 1))
示例14
def pad_irfft3(F, signal_size=None):
    """
    Padded batch inverse real 3-D FFT.

    Complex inverse FFTs are applied along the first two spatial axes, then an
    inverse real FFT along the last (one-sided) axis.

    :param F: real tensor of shape [..., res0, res1, res2 // 2 + 1, 2] holding
        (real, imag) pairs
    :param signal_size: length res2 of the reconstructed last axis. Defaults to
        ``F.shape[-3]`` (res0), which assumes a cubic grid as before.
    :return: real tensor of shape [..., res0, res1, res2]
    """
    res = F.shape[-3] if signal_size is None else signal_size
    # torch.ifft / torch.irfft were removed in PyTorch 1.8. Real dims -4/-3
    # become complex dims -3/-2 once the trailing (real, imag) dim is folded in.
    spec = torch.view_as_complex(F.contiguous())
    spec = torch.fft.ifft(spec, dim=-3)
    spec = torch.fft.ifft(spec, dim=-2)
    return torch.fft.irfft(spec, n=res, dim=-1)
示例15
def pad_ifft2(F):
    """
    Padded batch inverse 2-D FFT.

    :param F: real tensor of shape [..., res0, res1, 2] holding (real, imag) pairs
    :return: real tensor of shape [..., res0, res1, 2] holding the complex
        inverse DFT over (res0, res1)
    """
    # torch.ifft was removed in PyTorch 1.8; use the torch.fft module on a complex view.
    spec = torch.view_as_complex(F.contiguous())
    sig = torch.fft.ifft(spec, dim=-2)  # along res0
    sig = torch.fft.ifft(sig, dim=-1)  # along res1
    # Bug fix: the original returned the undefined name `f2` (NameError).
    return torch.view_as_real(sig)
示例16
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors.

    Each selected dimension of size n is circularly shifted by n // 2, which
    moves the zero-frequency entry to the centre.

    :param x: input tensor
    :param dim: int, sequence of ints, or None (all dimensions)
    :return: shifted tensor with the same shape as ``x``
    """
    if isinstance(dim, int):
        return torch.roll(x, x.shape[dim] // 2, dim)
    dims = tuple(range(x.dim())) if dim is None else tuple(dim)
    if not dims:
        return x  # nothing to shift (0-d tensor or empty dim list)
    shifts = tuple(x.shape[d] // 2 for d in dims)
    # torch.roll shifts several dimensions in one call.
    return torch.roll(x, shifts, dims)
示例17
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors.

    Undoes ``fftshift``. Each selected dimension of size n is circularly
    shifted by (n + 1) // 2, which equals -(n // 2) modulo n and is therefore
    also correct for odd n.

    :param x: input tensor
    :param dim: int, sequence of ints, or None (all dimensions)
    :return: shifted tensor with the same shape as ``x``
    """
    if isinstance(dim, int):
        return torch.roll(x, (x.shape[dim] + 1) // 2, dim)
    dims = tuple(range(x.dim())) if dim is None else tuple(dim)
    if not dims:
        return x  # nothing to shift (0-d tensor or empty dim list)
    shifts = tuple((x.shape[d] + 1) // 2 for d in dims)
    # torch.roll shifts several dimensions in one call.
    return torch.roll(x, shifts, dims)
示例18
def fft2_np(data):
    """
    Numpy version of fft2: a centered, orthonormal 2-D FFT over the last two axes.

    The input is ifft-shifted, transformed, and fft-shifted back, so the zero
    frequency ends up at the centre.
    """
    spatial = (-2, -1)
    centered = np.fft.ifftshift(data, axes=spatial)
    spectrum = np.fft.fft2(centered, norm="ortho")
    return np.fft.fftshift(spectrum, axes=spatial)
示例19
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors.

    Each selected dimension of size n is circularly shifted by n // 2, which
    moves the zero-frequency entry to the centre.

    :param x: input tensor
    :param dim: int, sequence of ints, or None (all dimensions)
    :return: shifted tensor with the same shape as ``x``
    """
    if isinstance(dim, int):
        return torch.roll(x, x.shape[dim] // 2, dim)
    dims = tuple(range(x.dim())) if dim is None else tuple(dim)
    if not dims:
        return x  # nothing to shift (0-d tensor or empty dim list)
    shifts = tuple(x.shape[d] // 2 for d in dims)
    # torch.roll shifts several dimensions in one call.
    return torch.roll(x, shifts, dims)
示例20
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors.

    Undoes ``fftshift``. Each selected dimension of size n is circularly
    shifted by (n + 1) // 2. That equals -(n // 2) modulo n, the same shift
    np.fft.ifftshift applies, so it is correct for both even and odd n.

    :param x: input tensor
    :param dim: int, sequence of ints, or None (all dimensions)
    :return: shifted tensor with the same shape as ``x``
    """
    if isinstance(dim, int):
        return torch.roll(x, (x.shape[dim] + 1) // 2, dim)
    dims = tuple(range(x.dim())) if dim is None else tuple(dim)
    if not dims:
        return x  # nothing to shift (0-d tensor or empty dim list)
    shifts = tuple((x.shape[d] + 1) // 2 for d in dims)
    # torch.roll shifts several dimensions in one call.
    return torch.roll(x, shifts, dims)
示例21
def forward(self, bottom):
    """
    Compact bilinear pooling of ``bottom`` using two sketch projections.

    Each spatial feature vector is projected with ``self.sparse_sketch_matrix1``
    and ``self.sparse_sketch_matrix2``. The two projections are then circularly
    convolved via an FFT product.

    :param bottom: tensor of shape [batch, channels, height, width], where
        channels must equal ``self.input_dim1`` (required by the view below)
    :return: [batch, output_dim] if ``self.sum_pool`` else
        [batch, height, width, output_dim]
    """
    batch_size, _, height, width = bottom.size()
    bottom_flat = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
    sketch_1 = bottom_flat.mm(self.sparse_sketch_matrix1)
    sketch_2 = bottom_flat.mm(self.sparse_sketch_matrix2)
    # torch.fft / torch.ifft were removed in PyTorch 1.8. With complex tensors
    # from the torch.fft module, the complex product is a plain `*`.
    spec_1 = torch.fft.fft(sketch_1, dim=-1)
    spec_2 = torch.fft.fft(sketch_2, dim=-1)
    cbp_flat = torch.fft.ifft(spec_1 * spec_2, dim=-1).real
    cbp = cbp_flat.reshape(batch_size, height, width, self.output_dim)
    if self.sum_pool:
        cbp = cbp.sum(dim=[1, 2])
    return cbp
示例22
def fft_test():
    """Check that two complex Butterfly factors multiply out to the size-4 DFT matrix.

    Complex matrices use a trailing dimension of size 2 holding (real, imag).
    P is the permutation matrix that swaps rows 1 and 2 (bit-reversal order
    [0, 2, 1, 3] for n = 4), with a zero imaginary part stacked on.
    """
    # DFT matrix for n = 4
    size = 4
    # torch.fft(x, 1) was removed in PyTorch 1.8. Take the DFT of the identity
    # with the torch.fft module, then convert back to the [..., 2] layout.
    eye = real_to_complex(torch.eye(size))
    DFT = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    P = torch.stack((torch.tensor([[1., 0., 0., 0.],
                                   [0., 0., 1., 0.],
                                   [0., 1., 0., 0.],
                                   [0., 0., 0., 1.]]),
                     torch.zeros((size, size))), dim=-1)
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    # M0 @ M1 @ P reproduces the DFT.
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    # Equivalently, permute the columns of M0 @ M1 by the bit-reversal order.
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
示例23
def _setup(self, config):
    """Build the fixed-order complex ButterflyProduct, its Adam optimizer and the DFT target.

    Reads config keys: 'size', 'seed', 'lr', 'n_steps_per_epoch'.
    """
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=True)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    self.br_perm = torch.tensor(bitreversal_permutation(size))
示例24
def _setup(self, config):
    """Build a learned-order complex ButterflyProduct, its Adam optimizer and the DFT target.

    Reads config keys: 'size', 'seed', 'semantic_loss_weight', 'lr', 'n_steps_per_epoch'.
    """
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=False)
    self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    self.br_perm = torch.tensor(bitreversal_permutation(size))
示例25
def _setup(self, config):
    """Build a sparsemax-ordered complex ButterflyProduct, its Adam optimizer and the DFT target.

    Reads config keys: 'size', 'seed', 'lr', 'n_steps_per_epoch'.
    """
    size = config['size']
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=size, complex=True, fixed_order=False, softmax_fn='sparsemax')
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    self.br_perm = torch.tensor(bitreversal_permutation(size))
示例26
def _setup(self, config):
    """Build a configurable complex ButterflyProduct, its Adam optimizer and the DFT target.

    Reads config keys: 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr',
    'n_steps_per_epoch', plus 'semantic_loss_weight' when the order is learned
    with softmax.
    """
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'])
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    self.br_perm = torch.tensor(bitreversal_permutation(size))
    # br_perm = bitreversal_permutation(size)
    # br_reverse = torch.tensor(list(br_perm[::-1]))
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//2][::-1])), torch.tensor(list(br_perm[size//2:][::-1]))))
    # Same as [6, 2, 4, 0, 7, 3, 5, 1], which is [0, 1]^4 * [0, 2, 1, 3]^2 * [6, 4, 2, 0, 7, 5, 3, 1]
    # br_reverse = torch.cat((torch.tensor(list(br_perm[:size//4][::-1])), torch.tensor(list(br_perm[size//4:size//2][::-1])), torch.tensor(list(br_perm[size//2:3*size//4][::-1])), torch.tensor(list(br_perm[3*size//4:][::-1]))))
    # self.br_perm = br_reverse
    # self.br_perm = torch.tensor([0, 7, 4, 3, 2, 5, 6, 1]) # Doesn't work
    # self.br_perm = torch.tensor([7, 3, 0, 4, 2, 6, 5, 1]) # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 7, 3]) # This works, [0, 1]^4 * [2, 0, 3, 1]^2 * [0, 2, 4, 6, 1, 3, 5, 7] or [1, 0]^4 * [0, 2, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([4, 0, 2, 6, 5, 1, 3, 7]) # Doesn't work, [0, 1]^4 * [2, 0, 1, 3]^2 * [0, 2, 4, 6, 1, 3, 5, 7]
    # self.br_perm = torch.tensor([1, 5, 3, 7, 0, 4, 2, 6]) # This works, [0, 1]^4 * [4, 6, 5, 7, 0, 4, 2, 6]
    # self.br_perm = torch.tensor([4, 0, 6, 2, 5, 1, 3, 7]) # Doesn't work
    # self.br_perm = torch.tensor([4, 0, 6, 2, 1, 5, 3, 7]) # Doesn't work
    # self.br_perm = torch.tensor([0, 4, 6, 2, 1, 5, 7, 3]) # Doesn't work
    # self.br_perm = torch.tensor([4, 1, 6, 2, 5, 0, 7, 3]) # This works, since it's just swapping 0 and 1
    # self.br_perm = torch.tensor([5, 1, 6, 2, 4, 0, 7, 3]) # This works, since it's swapping 4 and 5
示例27
def _setup(self, config):
    """Build a BlockPermProduct -> Block2x2DiagProduct model, its optimizer and the DFT target.

    Reads config keys: 'seed', 'size', 'lr', 'n_steps_per_epoch'. ``self.input``
    is the identity matrix in the [..., 2] (real, imag) layout.
    """
    torch.manual_seed(config['seed'])
    self.model = nn.Sequential(
        BlockPermProduct(size=config['size'], complex=True, share_logit=False),
        Block2x2DiagProduct(size=config['size'], complex=True)
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Bug fix: the old call torch.fft(x) omitted the required signal_ndim
    # argument, and torch.fft(x, n) was removed in PyTorch 1.8 anyway. This
    # computes the 1-D DFT along the last axis, matching the sibling setups.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    # self.target_matrix = size * torch.ifft(real_to_complex(torch.eye(size)))
    self.input = real_to_complex(torch.eye(size))
示例28
def _setup(self, config):
    """Build a Block2x2DiagProduct -> BlockPermProduct model, its optimizer and the DFT target.

    Reads config keys: 'seed', 'size', 'lr', 'n_steps_per_epoch'. ``self.input``
    is the identity matrix in the [..., 2] (real, imag) layout.
    """
    torch.manual_seed(config['seed'])
    self.model = nn.Sequential(
        Block2x2DiagProduct(size=config['size'], complex=True, decreasing_size=False),
        BlockPermProduct(size=config['size'], complex=True, share_logit=False, increasing_size=True),
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
    self.input = real_to_complex(torch.eye(size))
示例29
def _setup(self, config):
    """Build a ButterflyProduct with a learned permutation, its Adam optimizer and the DFT target.

    Reads config keys: 'seed', 'size', 'fixed_order', 'softmax_fn', 'lr',
    'n_steps_per_epoch', plus 'semantic_loss_weight' when the order is learned
    with softmax.
    """
    torch.manual_seed(config['seed'])
    self.model = ButterflyProduct(size=config['size'],
                                  complex=True,
                                  fixed_order=config['fixed_order'],
                                  softmax_fn=config['softmax_fn'],
                                  learn_perm=True)
    if (not config['fixed_order']) and config['softmax_fn'] == 'softmax':
        self.semantic_loss_weight = config['semantic_loss_weight']
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    size = config['size']
    # Target: DFT matrix of `size` in the [..., 2] (real, imag) layout. torch.fft(x, 1)
    # was removed in PyTorch 1.8, so the torch.fft module is applied to a complex view.
    eye = real_to_complex(torch.eye(size))
    self.target_matrix = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye.contiguous()), dim=-1))
示例30
def test_block2x2diagproduct():
    """Check that hand-set Block2x2DiagProduct factors reproduce the size-4 DFT.

    The model is applied to the identity with its columns reordered as
    [0, 2, 1, 3]. Complex values use a trailing dimension of size 2.
    """
    # Factorization of the DFT matrix
    size = 4
    model = Block2x2DiagProduct(size, complex=True)
    model.factors[1].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0]], [[1.0, 0.0]]], [[[1.0, 0.0]], [[-1.0, 0.0]]]]))
    model.factors[0].ABCD = nn.Parameter(torch.tensor([[[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[1.0, 0.0],
                                                         [0.0, -1.0]]],
                                                       [[[1.0, 0.0],
                                                         [1.0, 0.0]],
                                                        [[-1.0, 0.0],
                                                         [0.0, 1.0]]]]))
    eye_c = torch.stack((torch.eye(size), torch.zeros(size, size)), dim=-1)
    # torch.fft(x, 1) was removed in PyTorch 1.8; compute the DFT of the identity
    # with the torch.fft module and keep the [..., 2] layout.
    expected = torch.view_as_real(torch.fft.fft(torch.view_as_complex(eye_c), dim=-1))
    assert torch.allclose(model(eye_c[:, [0, 2, 1, 3]]), expected)