34 lines
881 B
Python
34 lines
881 B
Python
from asyncio import base_tasks
|
|
import torch
|
|
import pytest
|
|
|
|
from enhancer.loss import mean_absolute_error, mean_squared_error
|
|
|
|
# Loss callables shared by the parametrized tests in this module; each entry
# is instantiated once at import time.
loss_functions = [
    mean_absolute_error(),
    mean_squared_error(),
]
|
|
|
|
|
|
def check_loss_shapes_compatibility(loss_fun):
    """Exercise ``loss_fun`` with matching and mismatched batch sizes.

    The loss is first called on two random tensors of identical shape
    ``(4, 1, 1000)``; that call is expected to complete without raising.
    It is then called on tensors whose leading dimension differs (4 vs. 6),
    and that call must raise ``TypeError``.

    Args:
        loss_fun: Callable accepting two tensors.
    """
    trailing_dims = (1, 1000)
    matching_batch, mismatched_batch = 4, 6

    # Identical shapes on both sides: must run cleanly.
    first = torch.rand(matching_batch, *trailing_dims)
    second = torch.rand(matching_batch, *trailing_dims)
    loss_fun(first, second)

    # Leading dimensions disagree: the loss must reject the pair.
    first = torch.rand(matching_batch, *trailing_dims)
    second = torch.rand(mismatched_batch, *trailing_dims)
    with pytest.raises(TypeError):
        loss_fun(first, second)
|
|
|
|
|
|
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    """Run the shape-compatibility check on every entry of ``loss_functions``."""
    check_loss_shapes_compatibility(loss_fun=loss)
|
|
|
|
|
|
@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
    """Check that the loss result yields a Python ``float`` through ``.item()``."""
    input_shape = (4, 1, 1000)

    # Draw prediction first, then target, both with the same shape.
    prediction = torch.rand(*input_shape)
    target = torch.rand(*input_shape)

    scalar = loss(prediction, target).item()
    assert isinstance(scalar, float)
|