Python源码示例:torch.stft()
示例1
def __init__(self,
             n_fft: int = 400,
             win_length: Optional[int] = None,
             hop_length: Optional[int] = None,
             pad: int = 0,
             window_fn: Callable[..., Tensor] = torch.hann_window,
             power: Optional[float] = 2.,
             normalized: bool = False,
             wkwargs: Optional[dict] = None) -> None:
    """Store the STFT settings of a ``Spectrogram`` module and build its window.

    Args:
        n_fft: FFT size. With ``onesided=True`` in ``torch.stft`` the result has
            ``n_fft // 2 + 1`` frequency bins.
        win_length: Window length in samples; ``None`` means ``n_fft``.
        hop_length: Hop between frames in samples; ``None`` means
            ``win_length // 2``.
        pad: Two-sided padding amount, stored as-is.
        window_fn: Factory called with the window length (plus ``wkwargs``)
            to create the window, which is registered as buffer ``'window'``.
        power: Magnitude exponent, stored as-is (presumably consumed by the
            module's forward pass — not visible here).
        normalized: Normalization flag, stored as-is.
        wkwargs: Extra keyword arguments forwarded to ``window_fn``.
    """
    super(Spectrogram, self).__init__()
    self.n_fft = n_fft
    self.win_length = n_fft if win_length is None else win_length
    self.hop_length = self.win_length // 2 if hop_length is None else hop_length
    # Calling window_fn(length, **{}) is the same as window_fn(length).
    extra_window_args = {} if wkwargs is None else wkwargs
    self.register_buffer('window', window_fn(self.win_length, **extra_window_args))
    self.pad = pad
    self.power = power
    self.normalized = normalized
示例2
def test_istft_requires_nola(self):
    """``istft`` must accept a NOLA-satisfying window and reject one that is not.

    NOLA (nonzero overlap-add) holds for an all-ones window but fails for an
    all-zeros window, in which case ``istft`` is expected to raise.
    """
    stft = torch.zeros((3, 5, 2))
    shared = {'n_fft': 4, 'win_length': 4}
    # All-ones window: overlap-add of the squared window is nonzero -> accepted.
    torchaudio.functional.istft(stft, window=torch.ones(4), **shared)
    # All-zeros window: overlap-add vanishes -> RuntimeError.
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(stft, window=torch.zeros(4), **shared)
示例3
def _test_istft_of_sine(self, amplitude, L, n):
    """Check that ``istft`` inverts a hand-built STFT of ``amplitude*sin(2*pi/L*n*x)``.

    The reference spectrum corresponds to hop length and window size both equal
    to ``L`` with a rectangular window and no centering, i.e.
    ``torch.stft(sound, L, hop_length=L, win_length=L, window=torch.ones(L),
    center=False, normalized=False)``. Only the imaginary parts of bin ``n``
    (``-amplitude*L/2``) and of its mirror ``L - n`` (``+amplitude*L/2``) are set.
    """
    x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
    sound = amplitude * torch.sin(2 * math.pi / L * x * n)

    n_bins = L // 2 + 1
    spectrum = torch.zeros((n_bins, 2, 2))
    peak = (amplitude * L) / 2.0
    mirror = L - n
    if n < n_bins:
        spectrum[n, :, 1] = -peak
    if 0 <= mirror < n_bins:
        # mirror bin: the spectrum is symmetric about L // 2
        spectrum[mirror, :, 1] = peak

    estimate = torchaudio.functional.istft(
        spectrum, L, hop_length=L, win_length=L,
        window=torch.ones(L), center=False, normalized=False,
    )
    # Scaling by the amplitude enlarges the absolute error, hence atol=1e-3.
    _compare_estimate(sound, estimate, atol=1e-3)
示例4
def compute_torch_stft(audio, descriptor):
    """Return the STFT magnitude of ``audio`` configured by ``descriptor``.

    Args:
        audio (Tensor): signal accepted by ``torch.stft`` (1-D, or 2-D batch).
        descriptor (str): ``"<name>_<n_fft>_<hop_size>[_...]"``, e.g.
            ``"stft_1024_256"``. The leading name and any trailing fields are
            ignored here.

    Returns:
        Tensor: ``|STFT|`` with ``n_fft // 2 + 1`` frequency bins, computed
        with a Hann window of length ``n_fft`` created on ``audio``'s device
        (default ``center=True`` framing).
    """
    _, n_fft, hop_size, *_ = descriptor.split("_")
    n_fft = int(n_fft)
    hop_size = int(hop_size)
    window = torch.hann_window(n_fft, device=audio.device)
    # return_complex must be explicit for real input since PyTorch 2.0; the
    # complex modulus equals sqrt(re^2 + im^2) of the legacy (..., 2) layout.
    spec = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        window=window,
        return_complex=True,
    )
    return spec.abs()
示例5
def forward(self, x):
    """Return the normalized dB-scaled mel spectrogram of ``x``.

    Steps: optional pre-emphasis with ``self.preemp`` (applied on a temporary
    channel axis), STFT magnitude with ``self.n_fft`` / ``self.win_length`` /
    ``self.hop_length`` / ``self.win``, projection ``spec @ self.mel_basis``,
    conversion with ``_amp_to_db`` minus ``hparams.ref_level_db``, and finally
    ``_normalize``.
    """
    if self.preemp is not None:
        x = self.preemp(x.unsqueeze(1)).squeeze(1)
    # The previous call used the PyTorch 0.4 signature
    # ``torch.stft(x, frame_length, hop, fft_size=...)``; ``fft_size`` is not
    # accepted by any later release. center=False keeps the 0.4 behaviour of
    # not padding the signal.
    stft = torch.stft(
        x,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.win,
        center=False,
        return_complex=True,
    )
    # NOTE(review): the 0.4 API returned (..., frames, freq), which is what
    # ``spec @ mel_basis`` expects; the current API returns (..., freq, frames),
    # so transpose back. Confirm mel_basis is (n_freq, n_mels).
    spec = stft.abs().transpose(-1, -2)
    # convert linear spec to mel
    mel = torch.matmul(spec, self.mel_basis)
    # convert to db
    mel = _amp_to_db(mel) - hparams.ref_level_db
    return _normalize(mel)
示例6
def __call__(self, wav):
    """Return the window-energy-normalized STFT of ``wav`` as (2, F, T).

    Channel 0 of the result is the real part and channel 1 the imaginary part.
    The STFT uses ``self.nfft``, ``self.window_shift`` (hop),
    ``self.window_size`` and ``self.window``, and is divided by
    ``sqrt(sum(window ** 2))``. Runs without gradient tracking.
    """
    with torch.no_grad():
        # return_complex is required for real input since PyTorch 2.0;
        # view_as_real restores the (F, T, 2) real/imag layout used below.
        spec = torch.stft(wav, n_fft=self.nfft, hop_length=self.window_shift,
                          win_length=self.window_size, window=self.window,
                          return_complex=True)
        data = torch.view_as_real(spec)
        data /= self.window.pow(2).sum().sqrt()
        # FxTx2 -> 2xFxT
        data = data.transpose(1, 2).transpose(0, 1)
    return data
# transformer: frame splitter
示例7
def forward(self, audio):
    """Return the log10 mel spectrogram of ``audio``.

    ``audio`` is reflect-padded by ``(n_fft - hop_length) // 2`` samples on
    each side and dimension 1 is squeezed (assumes (batch, 1, time) input —
    TODO confirm). The STFT magnitude (no further centering) is
    left-multiplied by ``self.mel_basis`` and clamped at 1e-5 before
    ``log10``.
    """
    p = (self.n_fft - self.hop_length) // 2
    audio = F.pad(audio, (p, p), "reflect").squeeze(1)
    # return_complex is required for real input since PyTorch 2.0.
    fft = torch.stft(
        audio,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window,
        center=False,
        return_complex=True,
    )
    # |z| equals sqrt(re^2 + im^2) but has a zero (not NaN) gradient on
    # exactly-zero bins.
    magnitude = fft.abs()
    mel_output = torch.matmul(self.mel_basis, magnitude)
    log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
    return log_mel_spec
示例8
def forward(self, x):
    """
    Input: (nb_samples, nb_channels, nb_timesteps)
    Output:(nb_samples, nb_channels, nb_bins, nb_frames, 2)

    The last output axis holds the real and imaginary parts; nb_bins is
    ``n_fft // 2 + 1`` (onesided STFT).
    """
    nb_samples, nb_channels, _ = x.size()
    # merge nb_samples and nb_channels so one stft call covers every channel
    x = x.reshape(nb_samples * nb_channels, -1)
    # compute stft with parameters as close as possible to scipy settings;
    # return_complex is required for real input since PyTorch 2.0
    stft_f = torch.stft(
        x,
        n_fft=self.n_fft, hop_length=self.n_hop,
        window=self.window, center=self.center,
        normalized=False, onesided=True,
        pad_mode='reflect',
        return_complex=True,
    )
    # keep the (..., 2) real/imag layout promised above
    stft_f = torch.view_as_real(stft_f)
    # reshape back to channel dimension
    return stft_f.contiguous().view(nb_samples, nb_channels, *stft_f.shape[-3:])
示例9
def test_istft(self):
    """Batched and unbatched ``F.istft`` calls must agree on a simple spectrum.

    The spectrum has a real DC value of 4 in every frame and zeros elsewhere
    (shape 3 bins x 5 frames x 2).
    """
    spectrum = torch.zeros(3, 5, 2)
    spectrum[0, :, 0] = 4.
    self.assert_batch_consistencies(F.istft, spectrum, n_fft=4, length=4)
示例10
def test_batch_TimeStretch(self):
    """Stretching a batch must equal stretching one item and repeating it."""
    test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
    waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100
    n_fft = 2048
    hop = 512
    kwargs = {
        'n_fft': n_fft,
        'hop_length': hop,
        'win_length': n_fft,
        'window': torch.hann_window(n_fft),
        'center': True,
        'pad_mode': 'reflect',
        'normalized': True,
        'onesided': True,
    }
    rate = 2
    complex_specgrams = torch.stft(waveform, **kwargs)

    def make_stretch():
        # n_freq = n_fft // 2 + 1 = 1025 one-sided bins
        return torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=n_fft // 2 + 1,
            hop_length=hop,
        )

    # Single then transform then batch
    expected = make_stretch()(complex_specgrams).repeat(3, 1, 1, 1, 1)
    # Batch then transform
    computed = make_stretch()(complex_specgrams.repeat(3, 1, 1, 1, 1))
    self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
示例11
def test_istft_of_ones(self):
    """Inverting the STFT of ``torch.ones(4)`` (n_fft=4) must give back ones."""
    # Equivalent to: torch.stft(torch.ones(4), 4) -> only a real DC value of 4
    # in each of the 5 frames.
    spectrum = torch.zeros(3, 5, 2)
    spectrum[0, :, 0] = 4.
    reconstructed = torchaudio.functional.istft(spectrum, n_fft=4, length=4)
    _compare_estimate(torch.ones(4), reconstructed)
示例12
def test_istft_of_zeros(self):
    """Inverting an all-zero spectrum (the STFT of ``torch.zeros(4)``) gives zeros."""
    silent_spectrum = torch.zeros((3, 5, 2))
    reconstructed = torchaudio.functional.istft(silent_spectrum, n_fft=4, length=4)
    _compare_estimate(torch.zeros(4), reconstructed)
示例13
def test_istft_requires_overlap_windows(self):
    """``istft`` must raise when frames do not overlap.

    A window of length 1 advanced by 20 samples leaves gaps between frames.
    """
    spectrum = torch.zeros((3, 5, 2))
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(
            spectrum, n_fft=4, hop_length=20, win_length=1, window=torch.ones(1)
        )
示例14
def forward(self, x):
    """Return PCEN-compressed mel features of ``x`` and update the carried state.

    The STFT magnitude (``self.n_fft`` plus ``self.stft_kwargs``) is projected
    by ``self.f2m`` and passed to ``pcen``, or to ``pcen_cuda_kernel`` when
    ``self.use_cuda_kernel`` is set. The returned state is detached into
    ``self.last_state`` and ``self.empty`` is cleared.
    """
    # return_complex is required for real input since PyTorch 2.0; |STFT|
    # equals the L2 norm over the legacy (re, im) axis taken previously.
    stft_kwargs = dict(self.stft_kwargs)
    stft_kwargs['return_complex'] = True
    x = torch.stft(x, self.n_fft, **stft_kwargs).abs()
    # swap axes 1 and 2 before f2m (assumes a 3-D (batch, freq, time)
    # spectrogram — TODO confirm)
    x = self.f2m(x.permute(0, 2, 1))
    if self.use_cuda_kernel:
        # NOTE(review): this path passes `self.trainable`, whereas the CPU path
        # passes `self.training and self.trainable`; confirm the eval-mode
        # difference is intended.
        x, ls = pcen_cuda_kernel(x, self.eps, self.s, self.alpha, self.delta, self.r,
                                 self.trainable, self.last_state, self.empty)
    else:
        x, ls = pcen(x, self.eps, self.s, self.alpha, self.delta, self.r,
                     self.training and self.trainable, self.last_state, self.empty)
    # detach so the stored state does not keep this call's autograd graph alive
    self.last_state = ls.detach()
    self.empty = False
    return x
示例15
def __call__(self, pkg, cached_file=None):
    """Attach log-power spectrogram features (plus optional deltas) to ``pkg``.

    Args:
        pkg: package passed through ``format_package``; ``pkg['chunk']`` is the
            waveform (``size(0)`` is taken as its length). When ``cached_file``
            is given, ``pkg['chunk_beg_i']`` / ``pkg['chunk_end_i']`` are used
            to select the chunk's frames.
        cached_file: optional path of a tensor saved with ``torch.save``; if
            given, features are sliced from it instead of being recomputed.

    Returns:
        ``pkg`` with the features under ``self.name`` (the cached branch also
        keeps them under ``'lps'``) and ``pkg['dec_resolution'] = self.hop``.
    """
    pkg = format_package(pkg)
    wav = pkg['chunk']
    max_frames = wav.size(0) // self.hop
    if cached_file is not None:
        # load pre-computed data
        # NOTE(review): torch.load unpickles the file; only load trusted caches.
        X = torch.load(cached_file)
        beg_i = pkg['chunk_beg_i'] // self.hop
        end_i = pkg['chunk_end_i'] // self.hop
        X = X[:, beg_i:end_i]
        # kept for callers that read the hard-coded key
        pkg['lps'] = X
    else:
        wav = wav.to(self.device)
        # positional args: n_fft, hop_length, win_length (no window ->
        # rectangular); return_complex is required since PyTorch 2.0
        spec = torch.stft(wav, self.n_fft, self.hop, self.win, return_complex=True)
        X = spec.abs().cpu()[:, :max_frames]
        # power in dB; the 10e-20 (= 1e-19) floor keeps log10 finite
        X = 10 * torch.log10(X ** 2 + 10e-20)
        if self.der_order > 0:
            deltas = [X]
            for n in range(1, self.der_order + 1):
                deltas.append(librosa.feature.delta(X.numpy(), order=n))
            X = torch.from_numpy(np.concatenate(deltas))
    # store under self.name in both branches so the key does not depend on
    # whether a cache was used
    pkg[self.name] = X
    # Overwrite resolution to hop length
    pkg['dec_resolution'] = self.hop
    return pkg
示例16
def __call__(self, pkg):
    """Attach the dB log-power spectrogram of ``pkg['chunk']`` as ``pkg['lps']``.

    The STFT uses ``self.n_fft``, hop ``self.hop`` and window length
    ``self.win`` (no window tensor, so rectangular), is computed on
    ``self.device``, and is truncated to ``size(0) // self.hop`` frames
    (assumes a 1-D waveform — TODO confirm). Returns ``pkg``.
    """
    pkg = format_package(pkg)
    wav = pkg['chunk']
    max_frames = wav.size(0) // self.hop
    wav = wav.to(self.device)
    # return_complex is required for real input since PyTorch 2.0
    spec = torch.stft(wav, self.n_fft, self.hop, self.win, return_complex=True)
    X = spec.abs().cpu()[:, :max_frames]
    # the 10e-20 (= 1e-19) floor keeps log10 finite on silent bins
    pkg['lps'] = 10 * torch.log10(X ** 2 + 10e-20)
    return pkg
示例17
def is_stft(descriptor):
    """Return True when the feature *descriptor* string begins with ``"stft"``."""
    prefix = "stft"
    return descriptor.startswith(prefix)
示例18
def get_mel(self, x):
    """Return the linear-amplitude mel spectrogram of ``x``.

    STFT settings come from ``self.hp.audio`` (``n_fft``, ``hop_length``,
    ``win_length``) and ``self.window``. The magnitude is left-multiplied by
    ``self.mel_basis`` (presumably (n_mels, n_fft // 2 + 1) — confirm).
    """
    audio_hp = self.hp.audio
    # return_complex is required for real input since PyTorch 2.0
    stft = torch.stft(
        input=x,
        n_fft=audio_hp.n_fft,
        hop_length=audio_hp.hop_length,
        win_length=audio_hp.win_length,
        window=self.window,
        return_complex=True,
    )
    # |z| is the L2 norm over the (re, im) pair of the legacy layout
    mag = stft.abs()
    melspectrogram = torch.matmul(self.mel_basis, mag)
    return melspectrogram
示例19
def stft(y, scale='linear'):
    """Return the STFT magnitude of ``y`` (n_fft=1024, hop=256, Hann window).

    Args:
        y (Tensor): signal accepted by ``torch.stft``; the window is created on
            ``y``'s device.
        scale (str): ``'linear'`` returns ``D = sqrt(|STFT|^2 + 1e-10)``;
            ``'log'`` returns ``2 * log(D)``.

    Raises:
        ValueError: for any other ``scale`` (previously ``None`` was returned
        silently).
    """
    window = torch.hann_window(1024, device=y.device)
    # return_complex is required for real input since PyTorch 2.0;
    # view_as_real keeps the (..., 2) layout the power sum below relies on.
    spec = torch.view_as_real(torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                                         window=window, return_complex=True))
    D = torch.sqrt(spec.pow(2).sum(-1) + 1e-10)
    if scale == 'linear':
        return D
    if scale == 'log':
        return 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    raise ValueError(f"unknown scale {scale!r}; expected 'linear' or 'log'")
# STFT code is adapted from: https://github.com/pseeth/pytorch-stft
示例20
def _power_loss(self, p_y, t_y):
    """Mean squared difference between STFT magnitudes of prediction and target.

    Each input is flattened with ``x.reshape(x.shape[0])``, so it must hold
    exactly ``shape[0]`` elements, e.g. (N,) or (N, 1). The STFT uses
    n_fft=512, a Hann window on the input's device and the default hop.

    Returns:
        Tensor: scalar ``sum((|P| - |T|)^2) / (n_freqs * n_frames)``.
    """
    def _magnitude(signal):
        # |STFT| of the flattened signal; window follows the input's device
        signal = signal.reshape(signal.shape[0])
        window = torch.hann_window(window_length=512, device=signal.device)
        spec = torch.stft(signal, n_fft=512, window=window, return_complex=True)
        # |z| avoids the NaN gradient of sqrt(re^2 + im^2) on zero bins
        return spec.abs()

    power_orig = _magnitude(t_y)
    power_pred = _magnitude(p_y)
    # Summing squared per-frequency L2 norms over time and dividing by the
    # element count is exactly the mean of the squared differences.
    return torch.mean((power_pred - power_orig) ** 2)
示例21
def forward(self, x, stfts_orig):
    """Multi-scale spectral loss between ``x`` and precomputed target spectra.

    Args:
        x (Tensor): signal(s) passed to ``torch.stft`` at each scale in
            ``self.scales`` (window ``self.windows[i]``, hop
            ``int((1 - self.overlap) * scale)``, no centering).
        stfts_orig: target spectra indexed ``stfts_orig[scale][item]``, in the
            same arrangement as the ``amp`` outputs computed here.

    Returns:
        Sum over scales and items of the mean absolute spectral difference plus
        the mean absolute difference of ``log(spec + 1e-4)``.
    """
    stfts = []
    # First compute multiple STFT for x
    for i, scale in enumerate(self.scales):
        hop = int((1 - self.overlap) * scale)
        spec = torch.stft(x, n_fft=scale, window=self.windows[i], hop_length=hop,
                          center=False, return_complex=True)
        # presumably amp() expects the legacy (..., 2) real/imag layout —
        # view_as_real preserves it
        stfts.append(amp(torch.view_as_real(spec)))
    # Pair each predicted item with its target; the item count is taken from
    # each scale's own spectra (j depends on i).
    pairs = [(stfts_orig[i][j], pred[j])
             for i, pred in enumerate(stfts)
             for j in range(len(pred))]
    # Compute loss
    lin_loss = sum(torch.mean(torch.abs(orig - pred)) for orig, pred in pairs)
    log_loss = sum(torch.mean(torch.abs(torch.log(orig + 1e-4) - torch.log(pred + 1e-4)))
                   for orig, pred in pairs)
    return lin_loss + log_loss
示例22
def forward(self, x):
    """Return ``amp`` of the STFT of ``x`` at every scale in ``self.scales``.

    Each scale uses window ``self.windows[i]``, hop
    ``int((1 - self.overlap) * scale)`` and no centering. The result is
    indexed ``[scale][item]``; when ``self.reshape`` is set it is regrouped to
    ``[item][scale]`` with one inner list per ``x.shape[0]`` item.
    """
    stfts = []
    for i, scale in enumerate(self.scales):
        hop = int((1 - self.overlap) * scale)
        # return_complex is required for real input since PyTorch 2.0
        spec = torch.stft(x, n_fft=scale, window=self.windows[i], hop_length=hop,
                          center=False, return_complex=True)
        # presumably amp() expects the legacy (..., 2) real/imag layout
        stfts.append(amp(torch.view_as_real(spec)))
    if self.reshape:
        # regroup from [scale][item] to [item][scale]
        stfts = [[per_scale[b] for per_scale in stfts] for b in range(x.shape[0])]
    return stfts
示例23
def compute_stft(audio, n_fft=1024, win_length=1024, hop_length=256):
    """
    Computes STFT transformation of given audio

    The signal is zero-padded by ``win_length`` samples on both sides and then
    framed without additional centering, using a Hann window created on
    ``audio``'s device.

    Args:
        audio (Tensor): B x T, batch of audio (a single 1-D signal also works)
        n_fft (int): FFT size
        win_length (int): Hann window length
        hop_length (int): hop between frames
    Returns:
        mag (Tensor): STFT magnitudes
        real (Tensor): Real part of STFT transformation result
        im (Tensor): Imaginary part of STFT transformation result
        All three are laid out (..., n_frames, n_fft // 2 + 1).
    """
    win = torch.hann_window(win_length, device=audio.device)
    # manual padding stands in for centering (written for torch 0.4)
    pad = win_length
    audio = F.pad(audio, (pad, pad), 'constant')
    # `fft_size` belonged to the PyTorch 0.4 signature and is rejected today;
    # map onto the current keyword API (return_complex required since 2.0).
    stft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                      window=win, center=False, return_complex=True)
    # NOTE(review): the 0.4 API returned frames before frequencies; transpose
    # to keep that layout for callers. With win_length < n_fft the current API
    # centres the window inside each n_fft frame, so phases (real/im) may
    # differ from 0.4 results while magnitudes do not.
    stft = stft.transpose(-1, -2)
    power = stft.abs()
    return power, stft.real, stft.imag
示例24
def forward(self, x):
    """Return ``(spec, p)``: the dB-scaled spectrogram and the STFT magnitude.

    ``x`` optionally passes through ``self.preemp`` (on a temporary channel
    axis, dropping the last output sample). It is then zero-padded by
    ``self.win_length`` on both sides and transformed with ``self.n_fft``,
    ``self.win_length``, ``self.hop_length`` and ``self.win``. ``spec`` is
    ``_amp_to_db(p) - hparams.ref_level_db``.
    """
    if self.preemp is not None:
        # conv and remove last padding
        x = self.preemp(x.unsqueeze(1))[:, :, :-1].squeeze(1)
    # Manual zero padding replaces librosa-style centering (torch 0.4 had none).
    pad = self.win_length
    x = F.pad(x, (pad, pad), 'constant')
    # `fft_size` belonged to the PyTorch 0.4 signature and is rejected today;
    # center=False because the signal is already padded above.
    stft = torch.stft(x,
                      n_fft=self.n_fft,
                      hop_length=self.hop_length,
                      win_length=self.win_length,
                      window=self.win,
                      center=False,
                      return_complex=True)
    # NOTE(review): the 0.4 API returned (..., frames, freq); transpose to keep
    # that layout for consumers of `p`/`spec` — confirm.
    p = stft.abs().transpose(-1, -2)
    # convert volume to db
    spec = _amp_to_db(p) - hparams.ref_level_db
    return spec, p
示例25
def __call__(self, signal):
    """Return ``log1p(|STFT|)`` of ``signal`` as a NumPy array.

    ``signal`` is converted with ``torch.FloatTensor``. ``self.n_fft`` is used
    as both FFT size and Hamming window length, ``self.hop_length`` as hop,
    with no centering and a onesided result, i.e. shape
    (n_fft // 2 + 1, n_frames) for 1-D input.
    """
    # return_complex is required for real input since PyTorch 2.0
    spectrogram = torch.stft(
        torch.FloatTensor(signal),
        self.n_fft,
        hop_length=self.hop_length,
        win_length=self.n_fft,
        window=torch.hamming_window(self.n_fft),
        center=False,
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    magnitude = spectrogram.abs()
    return np.log1p(magnitude.numpy())
示例26
def _stft(self, data: torch.Tensor, n_fft: int, hop_length: int) -> torch.Tensor:
    """Return the centred, reflect-padded STFT of ``data`` in (..., F, T, 2) layout.

    ``win_length`` equals ``n_fft`` and the window is ``self._stft_window``.
    The last axis holds real and imaginary parts, as the legacy real-valued
    ``torch.stft`` output did.
    """
    # return_complex is required for real input since PyTorch 2.0;
    # view_as_real keeps the (..., 2) layout callers received before.
    spec = torch.stft(
        data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=self._stft_window,
        center=True,
        pad_mode='reflect',
        normalized=False,
        return_complex=True,
    )
    return torch.view_as_real(spec)
示例27
def stft(y):
    """Return ``(D, S)`` for ``y``: STFT magnitude and log-power spectrogram.

    The STFT uses n_fft=1024, hop=256 and a Hann window created on ``y``'s
    device. ``D = sqrt(|STFT|^2 + 1e-10)`` and ``S = 2 * log(D)``.
    """
    window = torch.hann_window(1024, device=y.device)
    # return_complex is required for real input since PyTorch 2.0;
    # view_as_real keeps the (..., 2) layout the power sum relies on.
    spec = torch.view_as_real(torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                                         window=window, return_complex=True))
    D = torch.sqrt(spec.pow(2).sum(-1) + 1e-10)
    S = 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    return D, S
示例28
def stft(y, scale='linear'):
    """Return the STFT magnitude of ``y`` (n_fft=1024, hop=256, rectangular window).

    Args:
        y (Tensor): signal accepted by ``torch.stft``.
        scale (str): ``'linear'`` returns ``D = sqrt(|STFT|^2 + 1e-10)``;
            ``'log'`` returns ``2 * log(D)``.

    Raises:
        ValueError: for any other ``scale`` (previously ``None`` was returned
        silently).
    """
    # An explicit all-ones window is what torch.stft uses when none is given;
    # passing it avoids the missing-window warning.
    window = torch.ones(1024, device=y.device)
    # return_complex is required for real input since PyTorch 2.0
    spec = torch.view_as_real(torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                                         window=window, return_complex=True))
    D = torch.sqrt(spec.pow(2).sum(-1) + 1e-10)
    if scale == 'linear':
        return D
    if scale == 'log':
        return 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    raise ValueError(f"unknown scale {scale!r}; expected 'linear' or 'log'")
示例29
def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided zero padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to divide the STFT by the window's L2 norm
            ``sqrt(sum(window ** 2))``

    Returns:
        Tensor: Dimension (..., freq, time), freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
        When ``power`` is None an extra trailing axis of size 2 holds the
        real and imaginary parts.
    """
    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram;
    # return_complex is required for real input since PyTorch 2.0 and
    # view_as_real keeps the (..., 2) layout complex_norm operates on.
    spec_f = torch.view_as_real(torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    ))

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])

    if normalized:
        spec_f = spec_f / window.pow(2.).sum().sqrt()
    if power is not None:
        spec_f = complex_norm(spec_f, power=power)

    return spec_f
示例30
def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None
) -> Tensor:
    r"""Create a triangular mel filterbank matrix mapping STFT bins to mel bands.

    Args:
        n_freqs (int): Number of STFT frequency bins (spaced linearly from 0 Hz
            to ``sample_rate // 2``)
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (Optional[str]): ``'slaney'`` scales each triangle by
            ``2 / (band width in Hz)`` (area normalization); ``None`` leaves the
            peaks at 1. (Default: ``None``)

    Returns:
        Tensor: Filterbank of size (``n_freqs``, ``n_mels``); each column is one
        mel filter. For a tensor ``A`` of size (..., ``n_freqs``) the mel
        projection is the matrix product ``A @ create_fb_matrix(A.size(-1), ...)``.

    Raises:
        ValueError: if ``norm`` is neither ``None`` nor ``'slaney'``.
    """
    if norm not in (None, "slaney"):
        raise ValueError("norm must be one of None or 'slaney'")

    def _hz_to_mel(freq):
        # HTK-style mel scale, as in Librosa's equivalent construction
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # centre frequency of every STFT bin
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    # n_mels + 2 band edges, equally spaced on the mel scale, mapped back to Hz
    mel_edges = torch.linspace(_hz_to_mel(f_min), _hz_to_mel(f_max), n_mels + 2)
    hz_edges = 700.0 * (10 ** (mel_edges / 2595.0) - 1.0)

    band_widths = hz_edges[1:] - hz_edges[:-1]  # (n_mels + 1)
    offsets = hz_edges.unsqueeze(0) - bin_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # rising edge from the left point and falling edge to the right point
    lower_edge = (-1.0 * offsets[:, :-2]) / band_widths[:-1]  # (n_freqs, n_mels)
    upper_edge = offsets[:, 2:] / band_widths[1:]  # (n_freqs, n_mels)
    fb = torch.max(torch.zeros(1), torch.min(lower_edge, upper_edge))

    if norm == "slaney":
        # roughly constant energy per channel
        fb *= (2.0 / (hz_edges[2:n_mels + 2] - hz_edges[:n_mels])).unsqueeze(0)

    return fb