shape/value checks
This commit is contained in:
		
							parent
							
								
									783a440609
								
							
						
					
					
						commit
						c5fdfbe188
					
				|  | @ -100,6 +100,9 @@ class Demucs(Model): | |||
|         if mixed_signal.dim() == 2: | ||||
|             mixed_signal = mixed_signal.unsqueeze(1) | ||||
| 
 | ||||
|         if mixed_signal.size(1)!=1: | ||||
|             raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels") | ||||
| 
 | ||||
|         length = mixed_signal.shape[-1] | ||||
|         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))  | ||||
|         if self.hparams.resample>1: | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ from torch.optim import Adam | |||
| import pytorch_lightning as pl | ||||
| 
 | ||||
| from enhancer.data.dataset import Dataset | ||||
| from enhancer.utils.loss import Avergeloss | ||||
| from enhancer.utils.loss import LOSS_MAP, Avergeloss | ||||
| 
 | ||||
| 
 | ||||
| class Model(pl.LightningModule): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786