refactor models

This commit is contained in:
shahules786 2022-10-05 12:50:26 +05:30
parent 2cf9803ed1
commit 451058c29d
2 changed files with 283 additions and 192 deletions

View File

@ -9,209 +9,255 @@ from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio as audio from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict from enhancer.utils.utils import merge_dict
class DemucsLSTM(nn.Module): class DemucsLSTM(nn.Module):
def __init__( def __init__(
self, self,
input_size:int, input_size: int,
hidden_size:int, hidden_size: int,
num_layers:int, num_layers: int,
bidirectional:bool=True bidirectional: bool = True,
): ):
super().__init__() super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) self.lstm = nn.LSTM(
input_size, hidden_size, num_layers, bidirectional=bidirectional
)
dim = 2 if bidirectional else 1 dim = 2 if bidirectional else 1
self.linear = nn.Linear(dim*hidden_size,hidden_size) self.linear = nn.Linear(dim * hidden_size, hidden_size)
def forward(self,x): def forward(self, x):
output,(h,c) = self.lstm(x) output, (h, c) = self.lstm(x)
output = self.linear(output) output = self.linear(output)
return output,(h,c) return output, (h, c)
class DemucsEncoder(nn.Module): class DemucsEncoder(nn.Module):
def __init__( def __init__(
self, self,
num_channels:int, num_channels: int,
hidden_size:int, hidden_size: int,
kernel_size:int, kernel_size: int,
stride:int=1, stride: int = 1,
glu:bool=False, glu: bool = False,
): ):
super().__init__() super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU() activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1 multi_factor = 2 if glu else 1
self.encoder = nn.Sequential( self.encoder = nn.Sequential(
nn.Conv1d(num_channels,hidden_size,kernel_size,stride), nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1), nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation activation,
) )
def forward(self,waveform): def forward(self, waveform):
return self.encoder(waveform) return self.encoder(waveform)
class DemucsDecoder(nn.Module):
class DemucsDecoder(nn.Module):
def __init__( def __init__(
self, self,
num_channels:int, num_channels: int,
hidden_size:int, hidden_size: int,
kernel_size:int, kernel_size: int,
stride:int=1, stride: int = 1,
glu:bool=False, glu: bool = False,
layer:int=0 layer: int = 0,
): ):
super().__init__() super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU() activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1 multi_factor = 2 if glu else 1
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1), nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
activation, activation,
nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride) nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
) )
if layer>0: if layer > 0:
self.decoder.add_module("4", nn.ReLU()) self.decoder.add_module("4", nn.ReLU())
def forward(self,waveform,): def forward(
self,
waveform,
):
out = self.decoder(waveform) out = self.decoder(waveform)
return out return out
class Demucs(Model): class Demucs(Model):
"""
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
parameters:
encoder_decoder: dict, optional
keyword arguments passsed to encoder decoder block
lstm : dict, optional
keyword arguments passsed to LSTM block
num_channels: int, defaults to 1
number channels in input audio
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
ED_DEFAULTS = { ED_DEFAULTS = {
"initial_output_channels":48, "initial_output_channels": 48,
"kernel_size":8, "kernel_size": 8,
"stride":1, "stride": 1,
"depth":5, "depth": 5,
"glu":True, "glu": True,
"growth_factor":2, "growth_factor": 2,
} }
LSTM_DEFAULTS = { LSTM_DEFAULTS = {
"bidirectional":True, "bidirectional": True,
"num_layers":2, "num_layers": 2,
} }
def __init__( def __init__(
self, self,
encoder_decoder:Optional[dict]=None, encoder_decoder: Optional[dict] = None,
lstm:Optional[dict]=None, lstm: Optional[dict] = None,
num_channels:int=1, num_channels: int = 1,
resample:int=4, resample: int = 4,
sampling_rate = 16000, sampling_rate=16000,
lr:float=1e-3, lr: float = 1e-3,
dataset:Optional[EnhancerDataset]=None, dataset: Optional[EnhancerDataset] = None,
loss:Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric:Union[str, List] = "mse" metric: Union[str, List] = "mse",
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None: if dataset is not None:
if sampling_rate!=dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels, super().__init__(
sampling_rate=sampling_rate,lr=lr, num_channels=num_channels,
dataset=dataset,duration=duration,loss=loss, metric=metric) sampling_rate=sampling_rate,
lr=lr,
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) dataset=dataset,
lstm = merge_dict(self.LSTM_DEFAULTS,lstm) duration=duration,
self.save_hyperparameters("encoder_decoder","lstm","resample") loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"] hidden = encoder_decoder["initial_output_channels"]
self.encoder = nn.ModuleList() self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList() self.decoder = nn.ModuleList()
for layer in range(encoder_decoder["depth"]): for layer in range(encoder_decoder["depth"]):
encoder_layer = DemucsEncoder(num_channels=num_channels, encoder_layer = DemucsEncoder(
hidden_size=hidden, num_channels=num_channels,
kernel_size=encoder_decoder["kernel_size"], hidden_size=hidden,
stride=encoder_decoder["stride"], kernel_size=encoder_decoder["kernel_size"],
glu=encoder_decoder["glu"], stride=encoder_decoder["stride"],
) glu=encoder_decoder["glu"],
)
self.encoder.append(encoder_layer) self.encoder.append(encoder_layer)
decoder_layer = DemucsDecoder(num_channels=num_channels, decoder_layer = DemucsDecoder(
hidden_size=hidden, num_channels=num_channels,
kernel_size=encoder_decoder["kernel_size"], hidden_size=hidden,
stride=1, kernel_size=encoder_decoder["kernel_size"],
glu=encoder_decoder["glu"], stride=1,
layer=layer glu=encoder_decoder["glu"],
) layer=layer,
self.decoder.insert(0,decoder_layer) )
self.decoder.insert(0, decoder_layer)
num_channels = hidden num_channels = hidden
hidden = self.ED_DEFAULTS["growth_factor"] * hidden hidden = self.ED_DEFAULTS["growth_factor"] * hidden
self.de_lstm = DemucsLSTM(input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"]
)
def forward(self,waveform): self.de_lstm = DemucsLSTM(
input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"],
)
def forward(self, waveform):
if waveform.dim() == 2: if waveform.dim() == 2:
waveform = waveform.unsqueeze(1) waveform = waveform.unsqueeze(1)
if waveform.size(1)!=1: if waveform.size(1) != 1:
raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels") raise TypeError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
)
length = waveform.shape[-1] length = waveform.shape[-1]
x = F.pad(waveform, (0,self.get_padding_length(length) - length)) x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample>1: if self.hparams.resample > 1:
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate, x = audio.resample_audio(
target_sr=int(self.hparams.sampling_rate * self.hparams.resample)) audio=x,
sr=self.hparams.sampling_rate,
target_sr=int(
self.hparams.sampling_rate * self.hparams.resample
),
)
encoder_outputs = [] encoder_outputs = []
for encoder in self.encoder: for encoder in self.encoder:
x = encoder(x) x = encoder(x)
encoder_outputs.append(x) encoder_outputs.append(x)
x = x.permute(0,2,1) x = x.permute(0, 2, 1)
x,_ = self.de_lstm(x) x, _ = self.de_lstm(x)
x = x.permute(0,2,1) x = x.permute(0, 2, 1)
for decoder in self.decoder: for decoder in self.decoder:
skip_connection = encoder_outputs.pop(-1) skip_connection = encoder_outputs.pop(-1)
x += skip_connection[..., :x.shape[-1]] x += skip_connection[..., : x.shape[-1]]
x = decoder(x) x = decoder(x)
if self.hparams.resample > 1: if self.hparams.resample > 1:
x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample), x = audio.resample_audio(
self.hparams.sampling_rate) x,
int(self.hparams.sampling_rate * self.hparams.resample),
self.hparams.sampling_rate,
)
return x return x
def get_padding_length(self,input_length): def get_padding_length(self, input_length):
input_length = math.ceil(input_length * self.hparams.resample) input_length = math.ceil(input_length * self.hparams.resample)
for layer in range(
for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation self.hparams.encoder_decoder["depth"]
input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1 ): # encoder operation
input_length = max(1,input_length) input_length = (
for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration math.ceil(
input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"] (input_length - self.hparams.encoder_decoder["kernel_size"])
input_length = math.ceil(input_length/self.hparams.resample) / self.hparams.encoder_decoder["stride"]
)
+ 1
)
input_length = max(1, input_length)
for layer in range(
self.hparams.encoder_decoder["depth"]
): # decoder operaration
input_length = (input_length - 1) * self.hparams.encoder_decoder[
"stride"
] + self.hparams.encoder_decoder["kernel_size"]
input_length = math.ceil(input_length / self.hparams.resample)
return int(input_length) return int(input_length)

View File

@ -7,76 +7,117 @@ from typing import Optional, Union, List
from enhancer.models.model import Model from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset from enhancer.data.dataset import EnhancerDataset
class WavenetDecoder(nn.Module):
class WavenetDecoder(nn.Module):
def __init__( def __init__(
self, self,
in_channels:int, in_channels: int,
out_channels:int, out_channels: int,
kernel_size:int=5, kernel_size: int = 5,
padding:int=2, padding: int = 2,
stride:int=1, stride: int = 1,
dilation:int=1, dilation: int = 1,
): ):
super(WavenetDecoder,self).__init__() super(WavenetDecoder, self).__init__()
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1) nn.LeakyReLU(negative_slope=0.1),
) )
def forward(self,waveform): def forward(self, waveform):
return self.decoder(waveform) return self.decoder(waveform)
class WavenetEncoder(nn.Module):
class WavenetEncoder(nn.Module):
def __init__( def __init__(
self, self,
in_channels:int, in_channels: int,
out_channels:int, out_channels: int,
kernel_size:int=15, kernel_size: int = 15,
padding:int=7, padding: int = 7,
stride:int=1, stride: int = 1,
dilation:int=1, dilation: int = 1,
): ):
super(WavenetEncoder,self).__init__() super(WavenetEncoder, self).__init__()
self.encoder = nn.Sequential( self.encoder = nn.Sequential(
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
),
nn.BatchNorm1d(out_channels), nn.BatchNorm1d(out_channels),
nn.LeakyReLU(negative_slope=0.1) nn.LeakyReLU(negative_slope=0.1),
) )
def forward( def forward(self, waveform):
self,
waveform
):
return self.encoder(waveform) return self.encoder(waveform)
class WaveUnet(Model): class WaveUnet(Model):
"""
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
parameters:
num_channels: int, defaults to 1
number of channels in input audio
depth : int, defaults to 12
depth of network
initial_output_channels: int, defaults to 24
number of output channels in intial upsampling layer
sampling_rate: int, defaults to 16KHz
sampling rate of input audio
lr : float, defaults to 1e-3
learning rate used for training
dataset: EnhancerDataset, optional
EnhancerDataset object containing train/validation data for training
duration : float, optional
chunk duration in seconds
loss : string or List of strings
loss function to be used, available ("mse","mae","SI-SDR")
metric : string or List of strings
metric function to be used, available ("mse","mae","SI-SDR")
"""
def __init__( def __init__(
self, self,
num_channels:int=1, num_channels: int = 1,
depth:int=12, depth: int = 12,
initial_output_channels:int=24, initial_output_channels: int = 24,
sampling_rate:int=16000, sampling_rate: int = 16000,
lr:float=1e-3, lr: float = 1e-3,
dataset:Optional[EnhancerDataset]=None, dataset: Optional[EnhancerDataset] = None,
duration:Optional[float]=None, duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse" metric: Union[str, List] = "mse",
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None: if dataset is not None:
if sampling_rate!=dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") logging.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels, super().__init__(
sampling_rate=sampling_rate,lr=lr, num_channels=num_channels,
dataset=dataset,duration=duration,loss=loss, metric=metric sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
) )
self.save_hyperparameters("depth") self.save_hyperparameters("depth")
self.encoders = nn.ModuleList() self.encoders = nn.ModuleList()
@ -84,72 +125,76 @@ class WaveUnet(Model):
out_channels = initial_output_channels out_channels = initial_output_channels
for layer in range(depth): for layer in range(depth):
encoder = WavenetEncoder(num_channels,out_channels) encoder = WavenetEncoder(num_channels, out_channels)
self.encoders.append(encoder) self.encoders.append(encoder)
num_channels = out_channels num_channels = out_channels
out_channels += initial_output_channels out_channels += initial_output_channels
if layer == depth -1 : if layer == depth - 1:
decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels) decoder = WavenetDecoder(
depth * initial_output_channels + num_channels, num_channels
)
else: else:
decoder = WavenetDecoder(num_channels+out_channels,num_channels) decoder = WavenetDecoder(
num_channels + out_channels, num_channels
)
self.decoders.insert(0,decoder) self.decoders.insert(0, decoder)
bottleneck_dim = depth * initial_output_channels bottleneck_dim = depth * initial_output_channels
self.bottleneck = nn.Sequential( self.bottleneck = nn.Sequential(
nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
padding=7),
nn.BatchNorm1d(bottleneck_dim), nn.BatchNorm1d(bottleneck_dim),
nn.LeakyReLU(negative_slope=0.1, inplace=True) nn.LeakyReLU(negative_slope=0.1, inplace=True),
) )
self.final = nn.Sequential( self.final = nn.Sequential(
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1), nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
nn.Tanh() nn.Tanh(),
) )
def forward( def forward(self, waveform):
self,waveform
):
if waveform.dim() == 2: if waveform.dim() == 2:
waveform = waveform.unsqueeze(1) waveform = waveform.unsqueeze(1)
if waveform.size(1)!=1: if waveform.size(1) != 1:
raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels") raise TypeError(
f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
)
encoder_outputs = [] encoder_outputs = []
out = waveform out = waveform
for encoder in self.encoders: for encoder in self.encoders:
out = encoder(out) out = encoder(out)
encoder_outputs.insert(0,out) encoder_outputs.insert(0, out)
out = out[:,:,::2] out = out[:, :, ::2]
out = self.bottleneck(out) out = self.bottleneck(out)
for layer,decoder in enumerate(self.decoders): for layer, decoder in enumerate(self.decoders):
out = F.interpolate(out, scale_factor=2, mode="linear") out = F.interpolate(out, scale_factor=2, mode="linear")
out = self.fix_last_dim(out,encoder_outputs[layer]) out = self.fix_last_dim(out, encoder_outputs[layer])
out = torch.cat([out,encoder_outputs[layer]],dim=1) out = torch.cat([out, encoder_outputs[layer]], dim=1)
out = decoder(out) out = decoder(out)
out = torch.cat([out, waveform],dim=1) out = torch.cat([out, waveform], dim=1)
out = self.final(out) out = self.final(out)
return out return out
def fix_last_dim(self,x,target): def fix_last_dim(self, x, target):
""" """
trying to do centre crop along last dimension centre crop along last dimension
""" """
assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension" assert (
x.shape[-1] >= target.shape[-1]
), "input dimension cannot be larger than target dimension"
if x.shape[-1] == target.shape[-1]: if x.shape[-1] == target.shape[-1]:
return x return x
diff = x.shape[-1] - target.shape[-1] diff = x.shape[-1] - target.shape[-1]
if diff%2!=0: if diff % 2 != 0:
x = F.pad(x,(0,1)) x = F.pad(x, (0, 1))
diff += 1 diff += 1
crop = diff//2 crop = diff // 2
return x[:,:,crop:-crop] return x[:, :, crop:-crop]