From 79525df76ef60bc83c68b5d8e4b1c2d1dc0f4058 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 29 Sep 2022 17:20:34 +0530 Subject: [PATCH] set sr to dataset sr --- enhancer/models/demucs.py | 7 ++++++- enhancer/models/waveunet.py | 8 +++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 0bb81d1..7c9d8ff 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,7 +1,8 @@ +import logging from typing import Optional, Union, List from torch import nn import torch.nn.functional as F -import math +import math from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset @@ -112,6 +113,10 @@ class Demucs(Model): ): duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None + if dataset is not None: + if sampling_rate!=dataset.sampling_rate: + logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + sampling_rate = dataset.sampling_rate super().__init__(num_channels=num_channels, sampling_rate=sampling_rate,lr=lr, dataset=dataset,duration=duration,loss=loss, metric=metric) diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index b354f55..f799352 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,5 +1,4 @@ -from tkinter import wantobjects -import wave +import logging import torch import torch.nn as nn import torch.nn.functional as F @@ -71,7 +70,10 @@ class WaveUnet(Model): metric:Union[str,List] = "mse" ): duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None - sampling_rate = sampling_rate if dataset is None else dataset.sampling_rate + if dataset is not None: + if sampling_rate!=dataset.sampling_rate: + logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}") + sampling_rate = dataset.sampling_rate super().__init__(num_channels=num_channels, sampling_rate=sampling_rate,lr=lr, dataset=dataset,duration=duration,loss=loss, metric=metric