From 1288565cff646b068c6e7749d2c4104beb048b4a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 10 Sep 2022 11:42:04 +0530 Subject: [PATCH] pass loss as arg --- enhancer/models/demucs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index bcc8214..6172714 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union, List from torch import nn import torch.nn.functional as F import math @@ -53,10 +53,11 @@ class Demucs(Model): sampling_rate = 16000, lr:float=1e-3, dataset:Optional[EnhancerDataset]=None, + loss:Union[str, List] = "mse" ): super().__init__(num_channels=num_channels, - sampling_rate=sampling_rate,lr=lr,dataset=dataset) + sampling_rate=sampling_rate,lr=lr,dataset=dataset,loss) encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) lstm = merge_dict(self.LSTM_DEFAULTS,lstm)