33 lines
		
	
	
		
			849 B
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			33 lines
		
	
	
		
			849 B
		
	
	
	
		
			Python
		
	
	
	
| import pytest
 | |
| import torch
 | |
| 
 | |
| from mayavoz.loss import mean_absolute_error, mean_squared_error
 | |
| 
 | |
# Loss callables fed to the parametrized tests in this module; each entry is
# built once, at import time, by calling its factory with default arguments.
loss_functions = [
    mean_absolute_error(),
    mean_squared_error(),
]
 | |
| 
 | |
| 
 | |
def check_loss_shapes_compatibility(loss_fun):
    """Call ``loss_fun`` with matching and with mismatching batch sizes.

    The first call passes two random tensors of identical shape
    ``(4, 1, 1000)``; any exception it raises propagates to the caller.
    The second call passes tensors whose leading dimensions are 4 and 6 and
    must raise ``TypeError``, otherwise ``pytest.raises`` fails the check.

    Args:
        loss_fun: Callable taking two tensors (first and second positional
            argument).
    """
    trailing_dims = (1, 1000)
    matching_batch = 4

    # Equal shapes: only required to complete without raising.
    first_input = torch.rand(matching_batch, *trailing_dims)
    second_input = torch.rand(matching_batch, *trailing_dims)
    loss_fun(first_input, second_input)

    # Different leading dimension (4 vs 6): a TypeError is expected.
    with pytest.raises(TypeError):
        smaller_batch = torch.rand(4, *trailing_dims)
        larger_batch = torch.rand(6, *trailing_dims)
        loss_fun(smaller_batch, larger_batch)
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    """Run the shared shape-compatibility check on each loss in ``loss_functions``."""
    check_loss_shapes_compatibility(loss_fun=loss)
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
    """Check that ``.item()`` on the loss result yields a Python ``float``.

    Both inputs are random tensors of shape ``(4, 1, 1000)``.
    """
    tensor_shape = (4, 1, 1000)
    prediction = torch.rand(*tensor_shape)
    target = torch.rand(*tensor_shape)

    scalar = loss(prediction, target).item()
    assert isinstance(scalar, float)
 |