tests loss
This commit is contained in:
parent
e4cd6ef10e
commit
9a3473ee8d
|
|
@ -0,0 +1,31 @@
|
|||
from asyncio import base_tasks
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from enhancer.utils.loss import mean_absolute_error, mean_squared_error
|
||||
|
||||
loss_functions = [mean_absolute_error(), mean_squared_error()]
|
||||
|
||||
def check_loss_shapes_compatibility(loss_fun):
|
||||
|
||||
batch_size = 4
|
||||
shape = (1,1000)
|
||||
loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("loss",loss_functions)
|
||||
def test_loss_input_shapes(loss):
|
||||
check_loss_shapes_compatibility(loss)
|
||||
|
||||
@pytest.mark.parametrize("loss",loss_functions)
|
||||
def test_loss_output_shapes(loss):
|
||||
|
||||
batch_size = 4
|
||||
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
|
||||
loss_value = loss(prediction, target)
|
||||
assert isinstance(loss_value.item(),float)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue