diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py new file mode 100644 index 0000000..f08ac65 --- /dev/null +++ b/tests/loss_function_test.py @@ -0,0 +1,31 @@ +from asyncio import base_tasks +import torch +import pytest + +from enhancer.utils.loss import mean_absolute_error, mean_squared_error + +loss_functions = [mean_absolute_error(), mean_squared_error()] + +def check_loss_shapes_compatibility(loss_fun): + + batch_size = 4 + shape = (1,1000) + loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape)) + + with pytest.raises(TypeError): + loss_fun(torch.rand(4,*shape),torch.rand(6,*shape)) + + +@pytest.mark.parametrize("loss",loss_functions) +def test_loss_input_shapes(loss): + check_loss_shapes_compatibility(loss) + +@pytest.mark.parametrize("loss",loss_functions) +def test_loss_output_shapes(loss): + + batch_size = 4 + prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000) + loss_value = loss(prediction, target) + assert isinstance(loss_value.item(),float) + +