Problem
When wrapping the binary_crossentropy loss function in a custom keras.losses.Loss subclass, the loss no longer accepts flat targets and instead requires targets of shape (..., 1). This does not happen when binary_crossentropy is simply wrapped in a plain function or in a class with a __call__() method.
How to reproduce
The following script can be used to reproduce this error.
The error is the following:
```
File "/keras/keras/src/losses/loss.py", line 43, in __call__
losses = self.call(y_true, y_pred)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/keras/reproduce_keras_error.py", line 45, in call
return self.loss(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/keras/keras/src/losses/losses.py", line 1782, in binary_crossentropy
ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/keras/keras/src/ops/nn.py", line 1398, in binary_crossentropy
return backend.nn.binary_crossentropy(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/keras/keras/src/backend/jax/nn.py", line 518, in binary_crossentropy
raise ValueError(
ValueError: Arguments `target` and `output` must have the same shape. Received: target.shape=(16,), output.shape=(16, 1)
```
Is there a recommended way?
If this is expected behaviour, what is the recommended way to wrap a loss function in a keras.losses.Loss subclass so that it handles both flat and (..., 1) target shapes?
Just like with layers, you're supposed to override call(), not __call__(). You can override __call__(), but by doing so you lose some built-in functionality, including auto-broadcasting. So override call() instead and call self.loss(y_true, y_pred) there.