pass loss as arg
This commit is contained in:
parent
4dbefd51b3
commit
1288565cff
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional
|
from typing import Optional, Union, List
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
import math
|
||||||
|
|
@ -53,10 +53,11 @@ class Demucs(Model):
|
||||||
sampling_rate = 16000,
|
sampling_rate = 16000,
|
||||||
lr:float=1e-3,
|
lr:float=1e-3,
|
||||||
dataset:Optional[EnhancerDataset]=None,
|
dataset:Optional[EnhancerDataset]=None,
|
||||||
|
loss:Union[str, List] = "mse"
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(num_channels=num_channels,
|
||||||
sampling_rate=sampling_rate,lr=lr,dataset=dataset)
|
sampling_rate=sampling_rate,lr=lr,dataset=dataset,loss)
|
||||||
|
|
||||||
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue