This commit is contained in:
shahules786 2022-09-05 14:52:53 +05:30
parent 88eea81261
commit b6d4bb52b4
3 changed files with 134 additions and 0 deletions

View File

91
enhancer/models/demucs.py Normal file
View File

@@ -0,0 +1,91 @@
# BUG FIX: `bool` is a builtin, not a member of `typing`;
# `from typing import bool` raises ImportError at module import time.
from torch import nn
class DeLSTM(nn.Module):
    """LSTM followed by a linear projection back to ``hidden_size`` features.

    Parameters
    ----------
    input_size : int
        Number of input features per timestep.
    hidden_size : int
        LSTM hidden state size; also the output feature size after projection.
    num_layers : int
        Number of stacked LSTM layers.
    bidirectional : bool, optional
        Use a bidirectional LSTM (default ``True``).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        # BUG FIX: nn.Module.__init__ must run before submodules are
        # assigned, otherwise setting self.lstm raises AttributeError.
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, bidirectional=bidirectional
        )
        # A bidirectional LSTM doubles the feature dimension of its output;
        # the linear layer projects it back to hidden_size.
        num_directions = 2 if bidirectional else 1
        self.linear = nn.Linear(num_directions * hidden_size, hidden_size)

    def forward(self, x):
        """Run the LSTM on ``x`` and project the output to ``hidden_size``."""
        output, (h, c) = self.lstm(x)
        output = self.linear(output)
        return output
class Demus(nn.Module):
    """Demucs-style convolutional encoder/decoder backbone with an LSTM
    bottleneck, intended for speech enhancement.

    NOTE(review): ``forward`` is not implemented yet, and the ``resample``
    argument is stored implicitly but never applied — both look like
    work-in-progress; confirm with the author.

    Parameters
    ----------
    c_in : int
        Number of input channels (default 1, mono audio).
    c_out : int
        Number of output channels (default 1).
    hidden : int
        Channel width of the first encoder layer; grows by ``growth_factor``
        per layer.
    kernel_size : int
        Convolution kernel size for every encoder/decoder layer.
    stride : int
        Stride of the down/up-sampling convolutions.
    growth_factor : int
        Multiplier applied to ``hidden`` after each layer.
    depth : int
        Number of encoder (and decoder) layers.
    glu : bool
        Use GLU activations (otherwise ReLU).
    bidirectional : bool
        Whether the bottleneck LSTM is bidirectional.
    resample : int
        Intended input resampling factor (currently unused).
    """

    def __init__(
        self,
        c_in: int = 1,
        c_out: int = 1,
        hidden: int = 48,
        kernel_size: int = 8,
        stride: int = 4,
        growth_factor: int = 2,
        depth: int = 6,
        glu: bool = True,
        bidirectional: bool = True,
        resample: int = 2,
    ):
        # BUG FIX: nn.Module.__init__ must run before any submodule or
        # parameter is assigned on self.
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.hidden = hidden
        self.growth_factor = growth_factor
        self.stride = stride
        self.kernel_size = kernel_size
        self.depth = depth
        self.bidirectional = bidirectional
        # GLU halves the channel dimension, so convolutions feeding it emit
        # multi_factor (=2) times the channels.
        self.activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        ## do resampling
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for layer in range(self.depth):
            # BUG FIX: the original passed a plain list to nn.Sequential
            # (TypeError: a list is not a Module) and then unpacked the
            # Sequential into ModuleList.append, which takes one module.
            encoder_layer = nn.Sequential(
                nn.Conv1d(c_in, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * multi_factor, kernel_size, 1),
                self.activation,
            )
            self.encoder.append(encoder_layer)
            decoder_layer = [
                nn.Conv1d(hidden, hidden * multi_factor, kernel_size, 1),
                self.activation,
                nn.ConvTranspose1d(hidden, c_out, kernel_size, stride),
            ]
            if layer > 0:
                # Every decoder stage except the outermost (layer 0, inserted
                # last at the front) ends with a ReLU.
                decoder_layer.append(nn.ReLU())
            # Decoders are built innermost-last and inserted at the front so
            # that decoder[i] pairs with encoder[i].
            self.decoder.insert(0, nn.Sequential(*decoder_layer))
            c_out = hidden
            c_in = hidden
            hidden = self.growth_factor * hidden
        # Bottleneck sequence model over the innermost feature size (c_in
        # now equals the deepest encoder output width).
        self.de_lstm = DeLSTM(
            input_size=c_in,
            hidden_size=c_in,
            num_layers=2,
            bidirectional=self.bidirectional,
        )

    def forward(self, input):
        # TODO: implement encoder -> de_lstm -> decoder (with skip
        # connections, per the Demucs architecture this mirrors).
        pass

43
enhancer/models/model.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import Optional
import pytorch_lightning as pl
from enhancer.data.dataset import Dataset
class Model(pl.LightningModule):
    """Base LightningModule that delegates data loading to an enhancer Dataset.

    Parameters
    ----------
    dataset : Dataset
        Data provider whose ``train_dataloader``/``val_dataloader`` are
        forwarded by this module.
    """

    def __init__(
        self,
        dataset: Dataset,
    ):
        super().__init__()
        # Stored through the property setter below.
        # (Removed a dead `pass` statement left after the assignment.)
        self.dataset = dataset

    @property
    def dataset(self):
        """The attached :class:`Dataset` instance."""
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(
        self,
        stage: Optional[str] = None,
    ):
        """Lightning hook: prepare the dataset for fitting and link it back
        to this model so the dataset can reach model attributes."""
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(
        self,
    ):
        """Return the training dataloader supplied by the dataset."""
        return self.dataset.train_dataloader()

    def val_dataloader(
        self,
    ):
        """Return the validation dataloader supplied by the dataset."""
        return self.dataset.val_dataloader()