diff --git a/mayavoz/models/__init__.py b/mayavoz/models/__init__.py index 9cf2b9b..6e82eb3 100644 --- a/mayavoz/models/__init__.py +++ b/mayavoz/models/__init__.py @@ -1,3 +1,3 @@ from mayavoz.models.demucs import Demucs -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel from mayavoz.models.waveunet import WaveUnet diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 372696c..278072f 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn from mayavoz.data import EnhancerDataset -from mayavoz.models import Model +from mayavoz.models import Mayamodel from mayavoz.models.complexnn import ( ComplexBatchNorm2D, ComplexConv2d, @@ -98,7 +98,7 @@ class DCCRN_DECODER(nn.Module): return self.decoder(waveform) -class DCCRN(Model): +class DCCRN(Mayamodel): STFT_DEFAULTS = { "window_len": 400, diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index a5e3147..db69c80 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torch import nn from mayavoz.data.dataset import EnhancerDataset -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel from mayavoz.utils.io import Audio as audio from mayavoz.utils.utils import merge_dict @@ -88,7 +88,7 @@ class DemucsDecoder(nn.Module): return out -class Demucs(Model): +class Demucs(Mayamodel): """ Demucs model from https://arxiv.org/pdf/1911.13254.pdf parameters: diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index d82c5c5..aede7a3 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -27,7 +27,7 @@ DEFAULT_DEVICE = "cpu" SAVE_NAME = "enhancer" -class Model(pl.LightningModule): +class Mayamodel(pl.LightningModule): """ Base class for all models parameters: @@ -288,8 +288,8 @@ class Model(pl.LightningModule): Returns ------- - model : Model - Model + model : Mayamodel + Mayamodel See also -------- diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index ead2146..9e5a4ae 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from mayavoz.data.dataset import EnhancerDataset -from mayavoz.models.model import Model +from mayavoz.models.model import Mayamodel class WavenetDecoder(nn.Module): @@ -66,7 +66,7 @@ class WavenetEncoder(nn.Module): return self.encoder(waveform) -class WaveUnet(Model): +class WaveUnet(Mayamodel): """ Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf parameters: