milliontrees.common.metrics package

Submodules

milliontrees.common.metrics.all_metrics module

class milliontrees.common.metrics.all_metrics.Accuracy(prediction_fn=None, name=None)[source]

Bases: ElementwiseMetric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.CountingError(score_threshold=0.1, name=None, geometry_name='y')[source]

Bases: ElementwiseMetric

Mean Absolute Error between ground truth and predicted detection counts.

Calculates MAE between the number of detections in ground truth vs predictions for each sample in the batch.
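As a rough illustration of the count comparison described above, the sketch below uses plain Python lists. The helper name counting_error is hypothetical, and the rule that a prediction counts only when its score exceeds score_threshold is an assumption rather than the library's confirmed behaviour.

```python
def counting_error(pred_scores_per_image, gt_counts, score_threshold=0.1):
    """Absolute count error per image (hypothetical helper, not the library API)."""
    errors = []
    for scores, n_true in zip(pred_scores_per_image, gt_counts):
        # Assumption: a prediction is counted when its score exceeds the threshold.
        n_pred = sum(1 for s in scores if s > score_threshold)
        errors.append(abs(n_pred - n_true))
    return errors


# Two images: the first has 2 confident predictions and 2 ground-truth objects,
# the second has 1 confident prediction and 3 ground-truth objects.
errors = counting_error([[0.9, 0.5, 0.05], [0.8]], [2, 3])
mae = sum(errors) / len(errors)
```

Averaging the per-image errors gives the batch-level MAE.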

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.DetectionAccuracy(iou_threshold=0.4, score_threshold=0.1, name=None, geometry_name='boxes', metric='accuracy')[source]

Bases: ElementwiseMetric

Per-image detection recall or accuracy with greedy 1:1 IoU matching.
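Greedy 1:1 IoU matching can be sketched as follows. This is a minimal pure-Python illustration, not the library's implementation: the helper names, the (xmin, ymin, xmax, ymax) box layout, the strict IoU comparison, and the accuracy definition tp / (tp + fp + fn) are all assumptions.

```python
def box_iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_match(pred_boxes, pred_scores, gt_boxes, iou_threshold=0.4):
    """Count true positives; each ground-truth box is matched at most once."""
    # Visit predictions from highest to lowest score.
    order = sorted(range(len(pred_boxes)), key=lambda i: pred_scores[i], reverse=True)
    used = set()
    tp = 0
    for i in order:
        best_j, best_iou = None, iou_threshold
        for j, gt in enumerate(gt_boxes):
            if j in used:
                continue
            iou = box_iou(pred_boxes[i], gt)
            if iou > best_iou:  # assumption: strict comparison
                best_j, best_iou = j, iou
        if best_j is not None:
            used.add(best_j)
            tp += 1
    return tp


preds = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
gts = [(0, 0, 10, 10), (50, 50, 61, 61)]

tp = greedy_match(preds, scores, gts)
recall = tp / len(gts)
# Assumed accuracy definition: tp / (tp + fp + fn).
accuracy = tp / (len(preds) + len(gts) - tp)
```

The second prediction overlaps the first ground-truth box, but that box is already taken by a higher-scoring prediction, so it counts as a false positive under the 1:1 rule.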

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.DetectionMAP(geometry_name='y', score_threshold=0.1, iou_type='bbox', name=None)[source]

Bases: Metric

Mean Average Precision for object detection using torchmetrics.

Supports bounding boxes (iou_type="bbox") and instance segmentation masks (iou_type="segm"). Single-class: all labels are normalised to 0 so that predictions always match the ground-truth class regardless of the model’s raw label output.

property agg_metric_field

The name of the key in the results dictionary returned by Metric.compute().

This should correspond to the aggregate metric computed on all of y_pred and y_true, in contrast to a group-wise evaluation.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.DummyMetric(prediction_fn=None, name=None)[source]

Bases: Metric

For testing purposes.

This Metric always returns -1.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.F1(prediction_fn=None, name=None, average='binary')[source]

Bases: Metric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.KeypointAccuracy(distance_threshold: float = 0.02, score_threshold: float = 0.1, name: str | None = None, geometry_name: str = 'y', image_size: int = 448)[source]

Bases: ElementwiseMetric

Keypoint accuracy for a one-class detector.

The distance_threshold is interpreted as a normalized distance with respect to the image size rather than raw pixels. For a square image of side length image_size, the effective pixel threshold is

pixel_threshold = distance_threshold * image_size

This makes the metric less sensitive to the absolute crop size while still behaving like a fixed-radius matching rule in pixel space.
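The conversion above can be checked with a short sketch. The helper names are hypothetical, and the use of Euclidean distance with an inclusive comparison is an assumption about how the matching radius is applied.

```python
import math


def pixel_threshold(distance_threshold=0.02, image_size=448):
    """Convert the normalized threshold to pixels for a square image."""
    return distance_threshold * image_size


def within_threshold(pred_xy, gt_xy, distance_threshold=0.02, image_size=448):
    # Assumption: Euclidean distance, and a point exactly on the radius matches.
    return math.dist(pred_xy, gt_xy) <= pixel_threshold(distance_threshold, image_size)


radius_448 = pixel_threshold()                   # about 8.96 pixels
radius_1000 = pixel_threshold(image_size=1000)   # about 20 pixels
near = within_threshold((100, 100), (105, 105))  # distance is about 7.07 pixels
far = within_threshold((100, 100), (110, 100))   # distance is 10 pixels
```

With the defaults, the matching radius is roughly 9 pixels on a 448-pixel crop and scales linearly with image_size.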

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.KeypointMergeCommissionMetric(distance_threshold: float = 0.02, score_threshold: float = 0.1, geometry_name: str = 'y', image_size: int = 448, name: str | None = None)[source]

Bases: ElementwiseMetric

Fraction of predictions within max_distance of two or more GT points.
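A minimal sketch of this fraction, assuming Euclidean distance in pixels, an inclusive comparison, and a value of 0.0 when there are no predictions. The function name is hypothetical and this is not the library's implementation.

```python
import math


def keypoint_merge_commission(pred_points, gt_points, max_distance):
    """Fraction of predicted points that have two or more GT points nearby."""
    if not pred_points:
        return 0.0  # assumption: no predictions means no merge errors
    merged = 0
    for p in pred_points:
        n_near = sum(1 for g in gt_points if math.dist(p, g) <= max_distance)
        if n_near >= 2:
            merged += 1
    return merged / len(pred_points)


preds = [(10, 10), (100, 100)]
gts = [(8, 10), (13, 10), (100, 101)]
# The first prediction sits between two GT points; the second is near only one.
merge_fraction = keypoint_merge_commission(preds, gts, max_distance=8.96)
```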

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MSE(name=None)[source]

Bases: ElementwiseLoss

class milliontrees.common.metrics.all_metrics.MaskAccuracy(iou_threshold=0.4, score_threshold=0.1, name=None, geometry_name='masks', metric='accuracy')[source]

Bases: ElementwiseMetric

Per-image mask recall or accuracy with greedy 1:1 mask IoU matching.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MaskAwareDetectionPrecision(iou_threshold=0.4, score_threshold=0.1, tree_fraction_threshold=0.5, require_tree_coverage_mask=False, name=None, geometry_name='boxes', tree_coverage_key='tree_coverage_mask')[source]

Bases: ElementwiseMetric

Precision metric that avoids penalizing predictions on unannotated tree regions.

Unmatched predictions are excluded from false positives when enough of their box area overlaps tree pixels from tree_coverage_mask.
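The exclusion rule can be illustrated with a small sketch. It assumes the IoU matching step has already produced a matched/unmatched flag per prediction, that the coverage mask is a 2D grid of 0/1 values, and that the comparison against tree_fraction_threshold is inclusive. All helper names here are hypothetical.

```python
def tree_fraction(box, tree_mask):
    """Fraction of pixels inside box (xmin, ymin, xmax, ymax) marked as tree."""
    xmin, ymin, xmax, ymax = (int(v) for v in box)
    pixels = [tree_mask[y][x] for y in range(ymin, ymax) for x in range(xmin, xmax)]
    return sum(pixels) / len(pixels) if pixels else 0.0


def mask_aware_precision(matched_flags, boxes, tree_mask, tree_fraction_threshold=0.5):
    """Precision where unmatched boxes on tree pixels are not false positives."""
    tp = sum(1 for m in matched_flags if m)
    fp = 0
    for matched, box in zip(matched_flags, boxes):
        if matched:
            continue
        if tree_fraction(box, tree_mask) >= tree_fraction_threshold:
            continue  # likely an unannotated tree: ignore instead of penalizing
        fp += 1
    return tp / (tp + fp) if (tp + fp) else 0.0  # assumption: 0.0 when empty


# 4x4 mask: the left two columns are tree pixels, the right two are not.
mask = [[1, 1, 0, 0] for _ in range(4)]
boxes = [(0, 0, 2, 2), (0, 2, 2, 4), (2, 0, 4, 2)]
matched = [True, False, False]
precision = mask_aware_precision(matched, boxes, mask)
```

Plain precision for this example would be 1/3. Ignoring the unmatched box that lies entirely on tree pixels raises it to 1/2.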

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MaskAwareKeypointPrecision(distance_threshold: float = 0.02, score_threshold: float = 0.1, tree_fraction_threshold: float = 0.5, require_tree_coverage_mask: bool = False, name: str | None = None, geometry_name: str = 'y', tree_coverage_key: str = 'tree_coverage_mask', image_size: int = 448)[source]

Bases: ElementwiseMetric

Precision for point detection that ignores unmatched points on tree-covered pixels.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MaskAwareMaskPrecision(iou_threshold=0.4, score_threshold=0.1, tree_fraction_threshold=0.5, require_tree_coverage_mask=False, name=None, geometry_name='masks', tree_coverage_key='tree_coverage_mask')[source]

Bases: ElementwiseMetric

Precision for mask detection that ignores unmatched masks on tree-covered pixels.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MergeCommissionMetric(iou_threshold: float = 0.4, score_threshold: float = 0.1, geometry_name: str = 'y', modality: str = 'bbox', name: str | None = None)[source]

Bases: ElementwiseMetric

Fraction of predictions with IoU > iou_threshold against two or more GT objects.
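For the bounding-box case, this fraction can be sketched as below. The helper names and the (xmin, ymin, xmax, ymax) box layout are assumptions, and predictions are assumed to have been filtered by score_threshold beforehand.

```python
def box_iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def merge_commission(pred_boxes, gt_boxes, iou_threshold=0.4):
    """Fraction of predictions overlapping two or more GT boxes above the threshold."""
    if not pred_boxes:
        return 0.0  # assumption: no predictions means no merge errors
    merged = 0
    for p in pred_boxes:
        n_hits = sum(1 for g in gt_boxes if box_iou(p, g) > iou_threshold)
        if n_hits >= 2:
            merged += 1
    return merged / len(pred_boxes)


preds = [(0, 0, 10, 10), (20, 20, 30, 30)]
gts = [(0, 0, 10, 6), (0, 4, 10, 10), (20, 20, 30, 30)]
# The first prediction has IoU 0.6 with two different GT boxes.
merge_fraction = merge_commission(preds, gts)
```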

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MultiTaskAccuracy(prediction_fn=None, name=None)[source]

Bases: MultiTaskMetric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.MultiTaskAveragePrecision(prediction_fn=None, name=None, average='macro')[source]

Bases: MultiTaskMetric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.PearsonCorrelation(name=None)[source]

Bases: Metric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.PrecisionAtRecall(threshold, score_fn=None, name=None)[source]

Bases: Metric

Given a specific model threshold, determine the precision score achieved.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.all_metrics.Recall(prediction_fn=None, name=None, average='binary')[source]

Bases: Metric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

milliontrees.common.metrics.all_metrics.binary_logits_to_pred(logits)[source]

milliontrees.common.metrics.all_metrics.binary_logits_to_score(logits)[source]

milliontrees.common.metrics.all_metrics.mse_loss(out, targets)[source]

milliontrees.common.metrics.all_metrics.multiclass_logits_to_pred(logits)[source]

Converts multi-class logits into predictions.

This function takes a tensor of logits with shape (batch_size, …, n_classes) and computes predictions by applying argmax along the last dimension.

Parameters:

logits (Tensor) – A tensor of shape (batch_size, …, n_classes) representing multi-class logits.

Returns:

A tensor containing predicted class indices.

Return type:

Tensor
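The argmax step can be illustrated with NumPy arrays standing in for Tensors. The function name below is a hypothetical stand-in, not the library function itself.

```python
import numpy as np


def multiclass_logits_to_pred_np(logits):
    """Pick the class with the largest logit along the last dimension."""
    return np.argmax(logits, axis=-1)


logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.5, 0.2]])
preds = multiclass_logits_to_pred_np(logits)

# Extra leading dimensions are preserved; only the class axis is reduced.
batched = multiclass_logits_to_pred_np(np.zeros((2, 4, 3)))
```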

milliontrees.common.metrics.all_metrics.pseudolabel_binary_logits(logits, confidence_threshold)[source]

Applies a confidence threshold to binary logits and generates pseudo-labels.

Parameters:
  • logits (Tensor) – A tensor of shape (batch_size, n_tasks) representing binary logits. A positive value (>0) indicates a positive prediction for the corresponding (example, task).

  • confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.

Returns:

  • unlabeled_y_pred (Tensor): A filtered version of logits, discarding rows (examples) where no predictions exceed the confidence threshold.

  • unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of logits, where entries below the confidence threshold are set to NaN. Rows with no confident predictions are discarded.

  • pseudolabels_kept_frac (float): The fraction of (example, task) pairs that are not set to NaN or discarded.

  • mask (Tensor): A mask indicating which predictions meet the confidence threshold.

Return type:

tuple
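The thresholding described above might look like the following NumPy sketch. It assumes that confidence is the larger of sigmoid(logit) and 1 - sigmoid(logit), that the comparison is inclusive, and that the returned mask is per row. None of these details are confirmed by the text above, and the function name is hypothetical.

```python
import numpy as np


def pseudolabel_binary_logits_np(logits, confidence_threshold):
    """Sketch of confidence-thresholded binary pseudo-labeling."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    # Assumption: confidence is the probability of the predicted side.
    confident = np.maximum(probs, 1.0 - probs) >= confidence_threshold
    pseudo = (logits > 0).astype(float)
    pseudo[~confident] = np.nan           # low-confidence entries become NaN
    kept_frac = float(confident.sum() / confident.size)
    row_mask = confident.any(axis=1)      # drop rows with no confident entry
    return logits[row_mask], pseudo[row_mask], kept_frac, row_mask


logits = np.array([[4.0, -0.1],
                   [0.2, -0.3],
                   [-5.0, 3.0]])
y_pred, pseudo, kept_frac, row_mask = pseudolabel_binary_logits_np(logits, 0.9)
```

In this example the middle row has no confident entry and is dropped, while the first row keeps one confident label and one NaN.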

milliontrees.common.metrics.all_metrics.pseudolabel_detection(preds, confidence_threshold)[source]

Filters detection predictions based on a confidence threshold.

Parameters:
  • preds (List[dict]) – A list of length batch_size, where each entry is a dictionary containing the following keys:

      - ‘boxes’ (Tensor): Bounding box coordinates.
      - ‘labels’ (Tensor): Class labels for detected objects.
      - ‘scores’ (Tensor): Confidence scores for each detection.
      - ‘losses’ (dict): An empty dictionary (not used).

  • confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.

Returns:

A filtered version of preds, where detections with confidence scores below confidence_threshold are removed.

Return type:

List[dict]

milliontrees.common.metrics.all_metrics.pseudolabel_detection_discard_empty(preds, confidence_threshold)[source]

Filters detection predictions based on a confidence threshold and discards empty entries.

Parameters:
  • preds (List[dict]) – A list of length batch_size, where each entry is a dictionary containing the following keys:

      - ‘boxes’ (Tensor): Bounding box coordinates.
      - ‘labels’ (Tensor): Class labels for detected objects.
      - ‘scores’ (Tensor): Confidence scores for each detection.
      - ‘losses’ (dict): An empty dictionary (not used).

  • confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.

Returns:

A filtered version of preds, where detections with confidence scores below confidence_threshold are removed. Entries with no remaining detections are discarded from the list.

Return type:

List[dict]
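Both detection filters can be sketched with plain Python lists standing in for Tensors. The function names are hypothetical, and keeping detections whose score equals the threshold follows the "below ... are removed" wording above.

```python
def filter_detections(preds, confidence_threshold):
    """Keep only detections whose score is at least the threshold."""
    filtered = []
    for pred in preds:
        keep = [i for i, s in enumerate(pred["scores"]) if s >= confidence_threshold]
        filtered.append({
            "boxes": [pred["boxes"][i] for i in keep],
            "labels": [pred["labels"][i] for i in keep],
            "scores": [pred["scores"][i] for i in keep],
            "losses": {},
        })
    return filtered


def filter_detections_discard_empty(preds, confidence_threshold):
    """Same filter, then drop images left with no detections."""
    return [p for p in filter_detections(preds, confidence_threshold) if p["scores"]]


preds = [
    {"boxes": [[0, 0, 5, 5], [1, 1, 4, 4]], "labels": [0, 0],
     "scores": [0.9, 0.2], "losses": {}},
    {"boxes": [[2, 2, 8, 8]], "labels": [0], "scores": [0.1], "losses": {}},
]
kept = filter_detections(preds, 0.5)
kept_nonempty = filter_detections_discard_empty(preds, 0.5)
```

The first variant preserves the batch length, which keeps predictions aligned with their images. The second does not, so any per-image bookkeeping has to be filtered in the same way.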

milliontrees.common.metrics.all_metrics.pseudolabel_identity(logits, confidence_threshold)[source]

milliontrees.common.metrics.all_metrics.pseudolabel_multiclass_logits(logits, confidence_threshold)[source]

Applies a confidence threshold to multi-class logits and generates pseudo-labels.

Parameters:
  • logits (Tensor) – A tensor of shape (batch_size, …, n_classes) representing multi-class logits.

  • confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.

Returns:

  • unlabeled_y_pred (Tensor): A filtered version of logits, discarding rows (examples) where no predictions exceed the confidence threshold.

  • unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of logits, where examples with confidence below the threshold are discarded.

  • pseudolabels_kept_frac (float): The fraction of examples retained after filtering.

  • mask (Tensor): A mask indicating which predictions meet the confidence threshold.

Return type:

tuple
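A NumPy sketch of the multi-class case follows. It assumes that confidence is the maximum softmax probability and that the comparison is inclusive. The function name is hypothetical, and 2D logits are used for simplicity.

```python
import numpy as np


def pseudolabel_multiclass_logits_np(logits, confidence_threshold):
    """Sketch of confidence-thresholded multi-class pseudo-labeling."""
    # Numerically stable softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    # Assumption: an example is kept when its top probability meets the threshold.
    mask = probs.max(axis=-1) >= confidence_threshold
    pseudo = logits.argmax(axis=-1)
    kept_frac = float(mask.mean())
    return logits[mask], pseudo[mask], kept_frac, mask


logits = np.array([[5.0, 0.0, 0.0],
                   [1.0, 1.0, 1.0],
                   [0.0, 0.0, 4.0]])
y_pred, pseudo, kept_frac, mask = pseudolabel_multiclass_logits_np(logits, 0.8)
```

The uniform middle row has a top probability of 1/3 and is dropped; the other two rows are kept with hard labels 0 and 2.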

milliontrees.common.metrics.loss module

class milliontrees.common.metrics.loss.ElementwiseLoss(loss_fn, name=None)[source]

Bases: ElementwiseMetric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (float): Worst-case metric

class milliontrees.common.metrics.loss.Loss(loss_fn, name=None)[source]

Bases: Metric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (float): Worst-case metric

class milliontrees.common.metrics.loss.MultiTaskLoss(loss_fn, name=None)[source]

Bases: MultiTaskMetric

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (float): Worst-case metric

milliontrees.common.metrics.metric module

class milliontrees.common.metrics.metric.ElementwiseMetric(name)[source]

Bases: Metric

Base class for metrics that are computed per element and then averaged.

property agg_metric_field

The name of the key in the results dictionary returned by Metric.compute().

compute_element_wise(y_pred, y_true, return_dict=True)[source]

Computes element-wise metric.

Parameters:
  • y_pred (Tensor) – Predicted targets or model output

  • y_true (Tensor) – True targets

  • return_dict (bool) – Whether to return the output as a dictionary or a tensor

Output (return_dict=False):
  • element_wise_metrics (Tensor): tensor of size (batch_size, )

Output (return_dict=True):
  • results (dict): Dictionary of results, mapping metric.name to element_wise_metrics

compute_flattened(y_pred, y_true, return_dict=True)[source]

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

class milliontrees.common.metrics.metric.Metric(name)[source]

Bases: object

Parent class for metrics.

property agg_metric_field

The name of the key in the results dictionary returned by Metric.compute().

This should correspond to the aggregate metric computed on all of y_pred and y_true, in contrast to a group-wise evaluation.

compute(y_pred, y_true, return_dict=True)[source]

Computes metric.

This is a wrapper around _compute.

Parameters:
  • y_pred (Tensor) – Predicted targets or model output

  • y_true (Tensor) – True targets

  • return_dict (bool) – Whether to return the output as a dictionary or a tensor

Output (return_dict=False):
  • metric (0-dim tensor): metric. If the inputs are empty, returns tensor(0.)

Output (return_dict=True):
  • results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric

compute_group_wise(y_pred, y_true, g, n_groups, return_dict=True)[source]

Computes metrics for each group.

This is a wrapper around _compute.

Parameters:
  • y_pred (Tensor) – Predicted targets or model output

  • y_true (Tensor) – True targets

  • g (Tensor) – groups

  • n_groups (int) – number of groups

  • return_dict (bool) – Whether to return the output as a dictionary or a tensor

Output (return_dict=False):
  • group_metrics (Tensor): tensor of size (n_groups, ) including the average metric for each group

  • group_counts (Tensor): tensor of size (n_groups, ) including the group count

  • worst_group_metric (0-dim tensor): worst-group metric

  • For empty inputs/groups, corresponding metrics are tensor(0.)

Output (return_dict=True):
  • results (dict): Dictionary of results
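The group-wise outputs can be illustrated with NumPy arrays standing in for Tensors. This sketch assumes an element-wise metric that is averaged within each group, and a worst-group value taken as the minimum over non-empty groups (appropriate when higher is better). The function name is hypothetical.

```python
import numpy as np


def group_wise_average(element_wise_metrics, g, n_groups):
    """Per-group mean, per-group count, and worst non-empty group."""
    group_metrics = np.zeros(n_groups)
    group_counts = np.zeros(n_groups)
    for idx in range(n_groups):
        in_group = g == idx
        group_counts[idx] = in_group.sum()
        if group_counts[idx] > 0:
            group_metrics[idx] = element_wise_metrics[in_group].mean()
        # Empty groups keep a metric of 0, as described above.
    nonempty = group_counts > 0
    # Assumption: higher is better, so the worst group has the minimum value.
    worst = float(group_metrics[nonempty].min()) if nonempty.any() else 0.0
    return group_metrics, group_counts, worst


per_sample = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
groups = np.array([0, 0, 1, 1, 1])
group_metrics, group_counts, worst = group_wise_average(per_sample, groups, n_groups=3)
```

Group 2 is empty here, so it reports a metric of 0 and a count of 0 but is excluded when picking the worst group in this sketch.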

group_count_field(group_idx)[source]

The name of the keys corresponding to each group’s count in the results dictionary returned by Metric.compute_group_wise().

group_metric_field(group_idx)[source]

The name of the keys corresponding to individual group evaluations in the results dictionary returned by Metric.compute_group_wise().

property name

Metric name.

Used to name the key in the results dictionaries returned by the metric.

worst(metrics)[source]

Given a list/numpy array/Tensor of metrics, computes the worst-case metric.

Parameters:

metrics (list, numpy array, or Tensor) – Metrics

Output:
  • worst_metric (0-dim tensor): Worst-case metric

property worst_group_metric_field

The name of the keys corresponding to the worst-group metric in the results dictionary returned by Metric.compute_group_wise().

class milliontrees.common.metrics.metric.MultiTaskMetric(name)[source]

Bases: Metric

compute_flattened(y_pred, y_true, return_dict=True)[source]

Module contents