coral_loss
rri_metrics.coral.coral_loss(
logits,
labels,
*,
num_classes,
importance_weights=None,
reduction='mean',
)
Compute the CORAL loss: the sum of binary cross-entropies over the K-1 ordinal thresholds.
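Written out, the per-sample loss can be sketched as follows. This assumes the standard CORAL (COnsistent RAnk Logits) formulation, in which binary task k asks whether the label exceeds k. Here z_0, …, z_{K-2} are the threshold logits, y is the label, and λ_k are the entries of importance_weights. Taking λ_k = 1 when importance_weights is None is an assumption, not something this page states.

$$
\mathcal{L}(\mathbf{z}, y) = -\sum_{k=0}^{K-2} \lambda_k \Bigl[ t_k \log \sigma(z_k) + (1 - t_k) \log\bigl(1 - \sigma(z_k)\bigr) \Bigr],
\qquad t_k = \mathbb{1}[y > k],
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}.
$$

Under this reading, reduction='none' returns \(\mathcal{L}\) for each element of the "..." batch shape, while 'mean' and 'sum' aggregate those per-sample values.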
Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| logits | Tensor | Tensor["... K-1", float32] threshold logits. | required |
| labels | Tensor | Tensor["...", int64] ordinal labels in [0, K-1]. | required |
| num_classes | int | Number of ordinal classes K. | required |
| importance_weights | Tensor \| None | Optional per-threshold weights, Tensor["K-1"]. | None |
| reduction | Literal['mean', 'sum', 'none'] | How the per-sample losses are reduced: 'mean', 'sum', or 'none'. | 'mean' |
Returns

| Type | Description |
| --- | --- |
| Tensor | Loss tensor. If reduction='none', a Tensor["..."] with one loss per sample; otherwise the reduced loss. |