set sr to dataset sr
This commit is contained in:
parent
4e033d2ab5
commit
79525df76e
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -112,6 +113,10 @@ class Demucs(Model):
|
||||||
|
|
||||||
):
|
):
|
||||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||||
|
if dataset is not None:
|
||||||
|
if sampling_rate!=dataset.sampling_rate:
|
||||||
|
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
||||||
|
sampling_rate = dataset.sampling_rate
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(num_channels=num_channels,
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
sampling_rate=sampling_rate,lr=lr,
|
||||||
dataset=dataset,duration=duration,loss=loss, metric=metric)
|
dataset=dataset,duration=duration,loss=loss, metric=metric)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from tkinter import wantobjects
|
import logging
|
||||||
import wave
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -71,7 +70,10 @@ class WaveUnet(Model):
|
||||||
metric:Union[str,List] = "mse"
|
metric:Union[str,List] = "mse"
|
||||||
):
|
):
|
||||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||||
sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate
|
if dataset is not None:
|
||||||
|
if sampling_rate!=dataset.sampling_rate:
|
||||||
|
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
||||||
|
sampling_rate = dataset.sampling_rate
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(num_channels=num_channels,
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
sampling_rate=sampling_rate,lr=lr,
|
||||||
dataset=dataset,duration=duration,loss=loss, metric=metric
|
dataset=dataset,duration=duration,loss=loss, metric=metric
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue