Skip to content
Snippets Groups Projects
Unverified Commit 0dc559f0 authored by Jan Buethe's avatar Jan Buethe
Browse files

added some bwe-related stuff

parent 5667867f
No related branches found
No related tags found
No related merge requests found
import torch
import scipy.signal
from utils.layers.fir import FIR
class TDLowpass(torch.nn.Module):
def __init__(self, numtaps, cutoff, power=2):
super().__init__()
self.b = scipy.signal.firwin(numtaps, cutoff)
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
self.power = power
def forward(self, y_true, y_pred):
assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
diff = y_true - y_pred
diff_lp = torch.nn.functional.conv1d(diff, self.weight)
loss = torch.mean(torch.abs(diff_lp ** self.power))
return loss, diff_lp
def get_freqz(self):
freq, response = scipy.signal.freqz(self.b)
return freq, response
\ No newline at end of file
import argparse
from scipy.io import wavfile
import torch
import numpy as np
from utils.layers.silk_upsampler import SilkUpsampler
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="input wave file")
parser.add_argument("output", type=str, help="output wave file")
if __name__ == "__main__":
args = parser.parse_args()
fs, x = wavfile.read(args.input)
# being lazy for now
assert fs == 16000 and x.dtype == np.int16
x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
upsampler = SilkUpsampler()
y = upsampler(x)
y = y.squeeze().numpy().astype(np.int16)
wavfile.write(args.output, 48000, y[13:])
\ No newline at end of file
import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F
class FIR(nn.Module):
def __init__(self, numtaps, bands, desired, fs=2):
super().__init__()
if numtaps % 2 == 0:
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
numtaps += 1
a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
self.weight = torch.from_numpy(a.astype(np.float32))
def forward(self, x):
num_channels = x.size(1)
weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
y = F.conv1d(x, weight, groups=num_channels)
return y
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment