# ============================================================
# enhancer/models/__init__.py  (new file, intentionally empty)
# ============================================================

# ============================================================
# enhancer/models/demucs.py  (new file)
# ============================================================
from torch import nn
# NOTE(review): the original wrote `from typing import bool`, which raises
# ImportError — `bool` is a builtin, not a typing export. Removed.


class DeLSTM(nn.Module):
    """LSTM (optionally bidirectional) followed by a linear projection.

    The linear layer maps the LSTM output back down to ``hidden_size`` so the
    module's output feature dimension is independent of directionality.

    Args:
        input_size: number of features per timestep of the input.
        hidden_size: LSTM hidden size (and output feature size after projection).
        num_layers: number of stacked LSTM layers.
        bidirectional: run the LSTM in both directions (default True).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        # super().__init__() is mandatory before assigning any submodule;
        # the original omitted it, which raises at the first assignment.
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, bidirectional=bidirectional
        )
        # A bidirectional LSTM doubles the output feature dimension.
        num_directions = 2 if bidirectional else 1
        self.linear = nn.Linear(num_directions * hidden_size, hidden_size)

    def forward(self, x):
        """Apply the LSTM, then project back to ``hidden_size`` features."""
        output, (h, c) = self.lstm(x)
        return self.linear(output)


class Demus(nn.Module):
    """Demucs-style convolutional encoder/decoder with a bottleneck LSTM.

    Skeleton implementation: the module graph is built but ``forward`` is not
    implemented yet.

    Args:
        c_in: input channels (e.g. 1 for mono audio).
        c_out: output channels.
        hidden: channel width of the first encoder layer; grows per layer.
        kernel_size: convolution kernel size.
        stride: convolution stride of the down/upsampling convolutions.
        growth_factor: per-layer multiplier applied to ``hidden``.
        depth: number of encoder (and decoder) layers.
        glu: use GLU activations (doubles conv output channels, GLU halves them).
        bidirectional: whether the bottleneck LSTM is bidirectional.
        resample: resampling factor (stored; resampling itself is TODO).
    """

    def __init__(
        self,
        c_in: int = 1,
        c_out: int = 1,
        hidden: int = 48,
        kernel_size: int = 8,
        stride: int = 4,
        growth_factor: int = 2,
        depth: int = 6,
        glu: bool = True,
        bidirectional: bool = True,
        resample: int = 2,
    ):
        # Required before any submodule assignment (missing in the original).
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.hidden = hidden
        self.growth_factor = growth_factor
        self.stride = stride
        self.kernel_size = kernel_size
        self.depth = depth
        self.bidirectional = bidirectional
        # Fix: `resample` was accepted but never stored in the original.
        self.resample = resample
        # GLU splits dim 1 in half, so convs feeding it emit 2x channels.
        self.activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1

        # TODO: input/output resampling by `self.resample` is not implemented.

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for layer in range(self.depth):
            # Fix: nn.Sequential must receive modules, not a list, and
            # ModuleList.append takes exactly one module (the original did
            # `nn.Sequential(list)` then `append(*sequential)`, both TypeErrors).
            encoder_layer = nn.Sequential(
                nn.Conv1d(c_in, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * multi_factor, kernel_size, 1),
                self.activation,
            )
            self.encoder.append(encoder_layer)

            decoder_parts = [
                nn.Conv1d(hidden, hidden * multi_factor, kernel_size, 1),
                self.activation,
                nn.ConvTranspose1d(hidden, c_out, kernel_size, stride),
            ]
            if layer > 0:
                # The layer built first (layer == 0) produces the final output
                # and therefore omits the trailing ReLU; deeper layers keep it.
                decoder_parts.append(nn.ReLU())
            # insert(0, ...) so decoder[0] is the deepest layer, mirroring
            # the encoder order for a U-Net-style pairing.
            self.decoder.insert(0, nn.Sequential(*decoder_parts))

            # Grow channel widths for the next (deeper) layer.
            c_out = hidden
            c_in = hidden
            hidden = self.growth_factor * hidden

        # Bottleneck LSTM operating on the deepest encoder's channel count.
        self.de_lstm = DeLSTM(
            input_size=c_in,
            hidden_size=c_in,
            num_layers=2,
            bidirectional=self.bidirectional,
        )

    def forward(self, input):
        # TODO: encoder -> de_lstm -> decoder with skip connections.
        pass


# ============================================================
# enhancer/models/model.py  (new file)
# ============================================================
from typing import Optional

import pytorch_lightning as pl

from enhancer.data.dataset import Dataset


class Model(pl.LightningModule):
    """Base LightningModule that delegates all data loading to a Dataset.

    Args:
        dataset: project Dataset object providing ``setup`` and the
            train/val dataloaders.
    """

    def __init__(self, dataset: Dataset):
        super().__init__()
        # Goes through the property setter below.
        self.dataset = dataset

    @property
    def dataset(self) -> Dataset:
        return self._dataset

    @dataset.setter
    def dataset(self, dataset: Dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        """Lightning hook: prepare the dataset and give it a back-reference."""
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        return self.dataset.val_dataloader()