tests loss

This commit is contained in:
shahules786 2022-09-14 11:48:46 +05:30
parent 9a3473ee8d
commit 2d415407e9
1 changed files with 1 additions and 1 deletions

View File

@ -21,7 +21,7 @@ def test_loss_input_shapes(loss):
check_loss_shapes_compatibility(loss)
@pytest.mark.parametrize("loss",loss_functions)
def test_loss_output_shapes(loss):
def test_loss_output_type(loss):
batch_size = 4
prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)