pass loss as arg

This commit is contained in:
shahules786 2022-09-10 11:42:04 +05:30
parent 4dbefd51b3
commit 1288565cff
1 changed files with 3 additions and 2 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Union, List
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
import math import math
@ -53,10 +53,11 @@ class Demucs(Model):
sampling_rate = 16000, sampling_rate = 16000,
lr:float=1e-3, lr:float=1e-3,
dataset:Optional[EnhancerDataset]=None, dataset:Optional[EnhancerDataset]=None,
loss:Union[str, List] = "mse"
): ):
super().__init__(num_channels=num_channels, super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr,dataset=dataset) sampling_rate=sampling_rate,lr=lr,dataset=dataset,loss)
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS,lstm) lstm = merge_dict(self.LSTM_DEFAULTS,lstm)