1 coral_loss

rri_metrics.coral_loss(
    logits,
    labels,
    *,
    num_classes,
    importance_weights=None,
    reduction='mean',
)

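Compute the CORAL (COnsistent RAnk Logits) loss: for each example, the sum over the K-1 thresholds of the binary cross-entropy between each threshold logit and its binary target.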
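Written out, and assuming the standard CORAL target encoding (the docstring does not spell it out, so treat this as an assumption), a label y is expanded into K-1 binary targets t_k, and each example's loss is a weighted sum of binary cross-entropies. Here g_k is the k-th threshold logit, σ is the logistic sigmoid (so σ(g_k) reads as an estimate of P(y > k)), and λ_k are the `importance_weights`, taken as all ones when that argument is None:

```latex
\begin{aligned}
t_k &= \mathbb{1}[y > k], \qquad k = 0, \dots, K-2, \\
\ell(g, y) &= -\sum_{k=0}^{K-2} \lambda_k \Big[ t_k \log \sigma(g_k) + (1 - t_k) \log\big(1 - \sigma(g_k)\big) \Big].
\end{aligned}
```

With reduction='none' this ℓ is returned for every element of the leading "..." dimensions; 'mean' and 'sum' reduce over those elements.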

1.1 Parameters

Name                Type                            Description                                        Default
logits              Tensor                          Threshold logits, Tensor["... K-1", float32].      required
labels              Tensor                          Ordinal labels in [0, K-1], Tensor["...", int64].  required
num_classes         int                             Number of ordinal classes K.                       required
importance_weights  Tensor | None                   Optional per-threshold weights, Tensor["K-1"].     None
reduction           Literal['mean', 'sum', 'none']  Reduction over the per-example losses.             'mean'
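To check shapes and values by hand, here is a minimal pure-Python reference sketch of how the parameters above interact. The name `coral_loss_reference` is hypothetical and this is not the `rri_metrics` implementation: it accepts plain lists with a single batch dimension (rather than Tensors with arbitrary leading dims), and it assumes the CORAL target encoding t_k = 1[label > k].

```python
import math
from typing import List, Optional, Sequence, Union


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)); avoids overflow in exp for large |x|.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def coral_loss_reference(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    *,
    num_classes: int,
    importance_weights: Optional[Sequence[float]] = None,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Hypothetical reference: one batch dimension, lists instead of Tensors."""
    num_thresholds = num_classes - 1
    # Default to uniform weights (lambda_k = 1) when none are given.
    weights = list(importance_weights) if importance_weights is not None else [1.0] * num_thresholds
    if len(weights) != num_thresholds:
        raise ValueError("importance_weights must have length K-1")
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")

    per_example: List[float] = []
    for row, label in zip(logits, labels):
        if len(row) != num_thresholds:
            raise ValueError("each logits row must have K-1 entries")
        if not 0 <= label < num_classes:
            raise ValueError("labels must lie in [0, K-1]")
        loss = 0.0
        for k, (g, w) in enumerate(zip(row, weights)):
            # Assumed CORAL encoding: target is 1 for every threshold below the label.
            target = 1.0 if label > k else 0.0
            # Binary cross-entropy with logits: log(1 - sigmoid(g)) == log(sigmoid(-g)).
            loss -= w * (target * _log_sigmoid(g) + (1.0 - target) * _log_sigmoid(-g))
        per_example.append(loss)

    if reduction == "none":
        return per_example
    if reduction == "sum":
        return sum(per_example)
    if reduction == "mean":
        return sum(per_example) / len(per_example)
    raise ValueError("reduction must be 'mean', 'sum', or 'none'")
```

With all-zero logits every threshold contributes log 2, so for K = 3 each example costs 2 log 2 (about 1.386) whatever its label. In a PyTorch setting, the inner loop corresponds to `torch.nn.functional.binary_cross_entropy_with_logits(logits, targets, weight=importance_weights, reduction='none').sum(-1)`, with `targets` a float tensor built from the labels under the same encoding assumption.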

1.2 Returns

Type    Description
Tensor  Loss tensor: Tensor["..."] of per-example losses if reduction='none'; otherwise a scalar ('mean' or 'sum').