mayavoz/tests/loss_function_test.py

32 lines
858 B
Python

from asyncio import base_tasks
import torch
import pytest
from enhancer.utils.loss import mean_absolute_error, mean_squared_error
# One instance per loss under test, in a fixed order; consumed by the
# parametrized tests in this module.
loss_functions = [loss_factory() for loss_factory in (mean_absolute_error, mean_squared_error)]
def check_loss_shapes_compatibility(loss_fun):
    """Check that ``loss_fun`` accepts matching batches and rejects mismatched ones.

    A prediction/target pair of identical shape ``(batch_size, 1, 1000)`` must be
    accepted without error; a pair whose batch dimensions differ must make the
    loss raise ``TypeError``.

    Args:
        loss_fun: Callable invoked as ``loss_fun(prediction, target)`` on tensors.
    """
    batch_size = 4
    shape = (1, 1000)

    # Matching shapes: the call must simply succeed.
    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

    # Mismatched batch sizes: the loss is required to reject them with TypeError.
    # The second size is derived from batch_size so the two always differ,
    # even if batch_size is changed later.
    mismatched_batch_size = batch_size + 2
    with pytest.raises(TypeError):
        loss_fun(
            torch.rand(batch_size, *shape),
            torch.rand(mismatched_batch_size, *shape),
        )
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    """Apply the shared shape-compatibility check to one loss instance."""
    check_loss_shapes_compatibility(loss_fun=loss)
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
    """Check that a loss reduces a batch to a single floating-point tensor value.

    ``Tensor.item()`` only works on a one-element tensor and returns a Python
    ``float`` for floating dtypes, so the final assertion already implies both
    properties. The explicit checks before it make a failure report which
    expectation was violated, instead of surfacing an opaque error from ``.item()``.
    """
    batch_size = 4
    shape = (1, 1000)
    prediction = torch.rand(batch_size, *shape)
    target = torch.rand(batch_size, *shape)

    loss_value = loss(prediction, target)

    assert isinstance(loss_value, torch.Tensor), (
        f"expected a torch.Tensor, got {type(loss_value).__name__}"
    )
    assert loss_value.numel() == 1, (
        f"expected a single-element loss, got shape {tuple(loss_value.shape)}"
    )
    assert isinstance(loss_value.item(), float)