shape/value checks
This commit is contained in:
		
							parent
							
								
									783a440609
								
							
						
					
					
						commit
						c5fdfbe188
					
				|  | @ -100,6 +100,9 @@ class Demucs(Model): | ||||||
|         if mixed_signal.dim() == 2: |         if mixed_signal.dim() == 2: | ||||||
|             mixed_signal = mixed_signal.unsqueeze(1) |             mixed_signal = mixed_signal.unsqueeze(1) | ||||||
| 
 | 
 | ||||||
|  |         if mixed_signal.size(1)!=1: | ||||||
|  |             raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels") | ||||||
|  | 
 | ||||||
|         length = mixed_signal.shape[-1] |         length = mixed_signal.shape[-1] | ||||||
|         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))  |         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))  | ||||||
|         if self.hparams.resample>1: |         if self.hparams.resample>1: | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ from torch.optim import Adam | ||||||
| import pytorch_lightning as pl | import pytorch_lightning as pl | ||||||
| 
 | 
 | ||||||
| from enhancer.data.dataset import Dataset | from enhancer.data.dataset import Dataset | ||||||
| from enhancer.utils.loss import Avergeloss | from enhancer.utils.loss import LOSS_MAP, Avergeloss | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Model(pl.LightningModule): | class Model(pl.LightningModule): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786