shape/value checks

This commit is contained in:
shahules786 2022-09-14 11:50:30 +05:30
parent 783a440609
commit c5fdfbe188
2 changed files with 4 additions and 1 deletions

View File

@ -100,6 +100,9 @@ class Demucs(Model):
if mixed_signal.dim() == 2:
mixed_signal = mixed_signal.unsqueeze(1)
if mixed_signal.size(1)!=1:
raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels")
length = mixed_signal.shape[-1]
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
if self.hparams.resample>1:

View File

@ -3,7 +3,7 @@ from torch.optim import Adam
import pytorch_lightning as pl
from enhancer.data.dataset import Dataset
from enhancer.utils.loss import Avergeloss
from enhancer.utils.loss import LOSS_MAP, Avergeloss
class Model(pl.LightningModule):