tests loss
This commit is contained in:
parent
9a3473ee8d
commit
2d415407e9
|
|
@ -21,7 +21,7 @@ def test_loss_input_shapes(loss):
|
||||||
check_loss_shapes_compatibility(loss)
|
check_loss_shapes_compatibility(loss)
|
||||||
|
|
||||||
@pytest.mark.parametrize("loss",loss_functions)
|
@pytest.mark.parametrize("loss",loss_functions)
|
||||||
def test_loss_output_shapes(loss):
|
def test_loss_output_type(loss):
|
||||||
|
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
|
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue