wave-u-net:
This commit is contained in:
parent
06afc95701
commit
48d5f9c21e
|
|
@ -1,5 +1,12 @@
|
||||||
|
from tkinter import wantobjects
|
||||||
|
import wave
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Optional, Union, List
|
||||||
|
|
||||||
|
from enhancer.models.model import Model
|
||||||
|
from enhancer.data.dataset import EnhancerDataset
|
||||||
|
|
||||||
class WavenetDecoder(nn.Module):
|
class WavenetDecoder(nn.Module):
|
||||||
|
|
||||||
|
|
@ -49,32 +56,39 @@ class WavenetEncoder(nn.Module):
|
||||||
return self.encoder(waveform)
|
return self.encoder(waveform)
|
||||||
|
|
||||||
|
|
||||||
|
class WaveUnet(Model):
|
||||||
|
|
||||||
class WaveUnet(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
inp_channels:int=1,
|
num_channels:int=1,
|
||||||
num_layers:int=12,
|
num_layers:int=12,
|
||||||
initial_output_channels:int=24
|
initial_output_channels:int=24,
|
||||||
|
sampling_rate:int=16000,
|
||||||
|
lr:float=1e-3,
|
||||||
|
dataset:Optional[EnhancerDataset]=None,
|
||||||
|
duration:Optional[float]=None,
|
||||||
|
loss: Union[str, List] = "mse",
|
||||||
|
metric:Union[str,List] = "mse"
|
||||||
):
|
):
|
||||||
super(WaveUnet,self).__init__()
|
super().__init__(num_channels=num_channels,
|
||||||
|
sampling_rate=sampling_rate,lr=lr,
|
||||||
|
dataset=dataset,duration=duration,loss=loss, metric=metric
|
||||||
|
)
|
||||||
|
self.save_hyperparameters("num_layers")
|
||||||
self.encoders = nn.ModuleList()
|
self.encoders = nn.ModuleList()
|
||||||
self.decoders = nn.ModuleList()
|
self.decoders = nn.ModuleList()
|
||||||
out_channels = initial_output_channels
|
out_channels = initial_output_channels
|
||||||
for layer in range(num_layers):
|
for layer in range(num_layers):
|
||||||
|
|
||||||
encoder = WavenetEncoder(inp_channels,out_channels)
|
encoder = WavenetEncoder(num_channels,out_channels)
|
||||||
self.encoders.append(encoder)
|
self.encoders.append(encoder)
|
||||||
|
|
||||||
inp_channels = out_channels
|
num_channels = out_channels
|
||||||
out_channels += initial_output_channels
|
out_channels += initial_output_channels
|
||||||
if layer == num_layers -1 :
|
if layer == num_layers -1 :
|
||||||
decoder = WavenetDecoder(num_layers * initial_output_channels + inp_channels,inp_channels)
|
decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,inp_channels)
|
||||||
else:
|
else:
|
||||||
decoder = WavenetDecoder(inp_channels+out_channels,inp_channels)
|
decoder = WavenetDecoder(num_channels+out_channels,num_channels)
|
||||||
|
|
||||||
self.decoders.insert(0,decoder)
|
self.decoders.insert(0,decoder)
|
||||||
|
|
||||||
|
|
@ -85,19 +99,54 @@ class WaveUnet(nn.Module):
|
||||||
nn.BatchNorm1d(bottleneck_dim),
|
nn.BatchNorm1d(bottleneck_dim),
|
||||||
nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
)
|
)
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
|
||||||
|
nn.Tanh()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(
    self, waveform
):
    """
    Run the encoder / bottleneck / decoder stack on ``waveform``.

    A 2-D input gets a channel axis inserted at dim 1 (presumably turning
    (batch, time) into (batch, 1, time) -- confirm against callers).
    After that the size of dim 1 must be exactly 1.

    Each encoder output is kept as a skip connection and the signal is
    then decimated by taking every second sample along the last axis.
    Each decoder step upsamples by 2 with linear interpolation, aligns
    the result to its skip tensor with ``self.fix_last_dim``,
    concatenates both along dim 1 and applies the decoder. The decoder
    output is finally concatenated with the input waveform and passed
    through ``self.final``.

    Raises:
        TypeError: if dim 1 of the (possibly unsqueezed) input is not 1.
    """
    if waveform.dim() == 2:
        waveform = waveform.unsqueeze(1)

    if waveform.size(1) != 1:
        raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")

    skips = []
    out = waveform
    for encoder in self.encoders:
        out = encoder(out)
        skips.append(out)
        # Decimate: keep every second sample along the last axis.
        out = out[:, :, ::2]

    out = self.bottleneck(out)

    # Decoders are consumed deepest-first, so pair them with the skip
    # tensors in reverse order of collection.
    for decoder, skip in zip(self.decoders, reversed(skips)):
        out = F.interpolate(out, scale_factor=2, mode="linear")
        # Upsampling an odd-length skip overshoots by one sample; align
        # the last axis before concatenating.
        out = self.fix_last_dim(out, skip)
        out = torch.cat([out, skip], dim=1)
        out = decoder(out)

    out = torch.cat([out, waveform], dim=1)
    return self.final(out)
|
||||||
|
|
||||||
|
def fix_last_dim(self, x, target):
    """
    Centre-crop ``x`` along its last dimension to the length of ``target``.

    When the length difference is odd, the extra sample is removed from
    the left side (the left crop is ``ceil(diff / 2)``, the right crop is
    ``floor(diff / 2)``).

    Args:
        x: tensor to crop; its last dimension must be at least as long
            as that of ``target``.
        target: tensor whose last-dimension length is the desired length.

    Returns:
        ``x`` itself when the lengths already match, otherwise a slice
        of ``x`` whose last dimension has the length of ``target``.

    Raises:
        AssertionError: if ``x`` is shorter than ``target`` along the
            last dimension.
    """
    target_len = target.shape[-1]
    assert x.shape[-1] >= target_len, "input dimension cannot be smaller than target dimension"
    if x.shape[-1] == target_len:
        return x

    diff = x.shape[-1] - target_len
    # Rounding the left offset up gives the same elements as padding one
    # sample on the right and cropping symmetrically, without the copy.
    start = (diff + 1) // 2
    return x[..., start:start + target_len]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue