set sr to dataset sr

This commit is contained in:
shahules786 2022-09-29 17:20:34 +05:30
parent 4e033d2ab5
commit 79525df76e
2 changed files with 11 additions and 4 deletions

View File

@ -1,7 +1,8 @@
import logging
from typing import Optional, Union, List
from torch import nn
import torch.nn.functional as F
import math
import math
from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset
@ -112,6 +113,10 @@ class Demucs(Model):
):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric)

View File

@ -1,5 +1,4 @@
from tkinter import wantobjects
import wave
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -71,7 +70,10 @@ class WaveUnet(Model):
metric:Union[str,List] = "mse"
):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric