milliontrees.common.metrics package¶
Submodules¶
milliontrees.common.metrics.all_metrics module¶
- class milliontrees.common.metrics.all_metrics.Accuracy(prediction_fn=None, name=None)[source]¶
Bases:
ElementwiseMetric
- class milliontrees.common.metrics.all_metrics.CountingError(score_threshold=0.1, name=None, geometry_name='y')[source]¶
Bases:
ElementwiseMetric
Mean Absolute Error between ground truth and predicted detection counts.
Calculates MAE between the number of detections in ground truth vs predictions for each sample in the batch.
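The count comparison described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the helper names `counting_errors` and `mean_absolute_error` are hypothetical, inputs are plain lists rather than tensors, and keeping detections whose score is at or above `score_threshold` is an assumption.

```python
def counting_errors(pred_scores, true_counts, score_threshold=0.1):
    """Absolute count error for each sample in a batch.

    pred_scores: one list of confidence scores per sample.
    true_counts: number of ground-truth detections per sample.
    """
    errors = []
    for scores, n_true in zip(pred_scores, true_counts):
        # Assumption: a detection is counted when its score >= score_threshold.
        n_pred = sum(1 for s in scores if s >= score_threshold)
        errors.append(abs(n_pred - n_true))
    return errors


def mean_absolute_error(errors):
    """Average of the per-sample errors; 0.0 for an empty batch (assumption)."""
    return sum(errors) / len(errors) if errors else 0.0


# Sample 1 keeps two of three predictions against 3 ground-truth objects;
# sample 2 keeps one prediction against 1 ground-truth object.
errors = counting_errors([[0.9, 0.05, 0.4], [0.2]], [3, 1])
mae = mean_absolute_error(errors)
```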
- class milliontrees.common.metrics.all_metrics.DetectionAccuracy(iou_threshold=0.4, score_threshold=0.1, name=None, geometry_name='boxes', metric='accuracy')[source]¶
Bases:
ElementwiseMetric
Per-image detection recall or accuracy with greedy 1:1 IoU matching.
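A minimal sketch of greedy 1:1 IoU matching for a single image follows. It is not taken from the library: `box_iou` and `greedy_match` are hypothetical helpers, boxes are assumed to be `(xmin, ymin, xmax, ymax)` tuples, predictions are visited in descending score order, a match requires IoU strictly above the threshold, and "accuracy" is taken to be TP / (TP + FP + FN). All of these are assumptions.

```python
def box_iou(a, b):
    """IoU of two boxes given as (xmin, ymin, xmax, ymax)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_match(pred_boxes, pred_scores, gt_boxes, iou_threshold=0.4):
    """Count true positives; each ground-truth box is matched at most once."""
    order = sorted(range(len(pred_boxes)), key=lambda i: -pred_scores[i])
    unmatched_gt = set(range(len(gt_boxes)))
    tp = 0
    for i in order:
        best_j, best_iou = None, iou_threshold
        for j in unmatched_gt:
            iou = box_iou(pred_boxes[i], gt_boxes[j])
            if iou > best_iou:
                best_j, best_iou = j, iou
        if best_j is not None:
            unmatched_gt.remove(best_j)  # enforce the 1:1 constraint
            tp += 1
    return tp


gt = [(0, 0, 10, 10), (20, 20, 30, 30)]
preds = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]

tp = greedy_match(preds, scores, gt)
fp = len(preds) - tp
fn = len(gt) - tp
recall = tp / len(gt)
accuracy = tp / (tp + fp + fn)
```

The second prediction overlaps the first ground-truth box well, but that box is already claimed by the higher-scoring prediction, so it counts as a false positive.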
- class milliontrees.common.metrics.all_metrics.DetectionMAP(geometry_name='y', score_threshold=0.1, iou_type='bbox', name=None)[source]¶
Bases:
Metric
Mean Average Precision for object detection using torchmetrics.
Supports bounding boxes (iou_type="bbox") and instance segmentation masks (iou_type="segm"). Single-class: all labels are normalised to 0 so that predictions always match the ground-truth class regardless of the model’s raw label output.
- property agg_metric_field¶
The name of the key in the results dictionary returned by Metric.compute().
This should correspond to the aggregate metric computed on all of y_pred and y_true, in contrast to a group-wise evaluation.
- class milliontrees.common.metrics.all_metrics.DummyMetric(prediction_fn=None, name=None)[source]¶
Bases:
Metric
For testing purposes.
This Metric always returns -1.
- class milliontrees.common.metrics.all_metrics.F1(prediction_fn=None, name=None, average='binary')[source]¶
Bases:
Metric
- class milliontrees.common.metrics.all_metrics.KeypointAccuracy(distance_threshold: float = 0.02, score_threshold: float = 0.1, name: str | None = None, geometry_name: str = 'y', image_size: int = 448)[source]¶
Bases:
ElementwiseMetric
Keypoint accuracy for a one-class detector.
The distance_threshold is interpreted as a normalized distance with respect to the image size rather than raw pixels. For a square image of side length image_size, the effective pixel threshold is
pixel_threshold = distance_threshold * image_size
This makes the metric less sensitive to the absolute crop size while still behaving like a fixed-radius matching rule in pixel space.
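With the defaults, the formula gives 0.02 * 448 = 8.96 pixels. The sketch below works through that arithmetic; `pixel_threshold` and `within_radius` are hypothetical helpers written for illustration, and treating a distance equal to the threshold as a match is an assumption.

```python
import math


def pixel_threshold(distance_threshold=0.02, image_size=448):
    """Convert the normalized distance into a radius in pixels."""
    return distance_threshold * image_size


def within_radius(pred, gt, distance_threshold=0.02, image_size=448):
    """True when a predicted point lies within the pixel radius of a GT point."""
    return math.dist(pred, gt) <= pixel_threshold(distance_threshold, image_size)


near = within_radius((100, 100), (105, 105))   # distance is about 7.07 px
far = within_radius((100, 100), (110, 100))    # distance is exactly 10 px
# The same 10 px offset matches on a larger crop, where the radius is 17.92 px.
far_on_large_crop = within_radius((100, 100), (110, 100), image_size=896)
```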
- class milliontrees.common.metrics.all_metrics.KeypointMergeCommissionMetric(distance_threshold: float = 0.02, score_threshold: float = 0.1, geometry_name: str = 'y', image_size: int = 448, name: str | None = None)[source]¶
Bases:
ElementwiseMetric
Fraction of predictions within max_distance of two or more GT points.
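As a rough illustration of this fraction, the sketch below counts predicted points that sit close to at least two ground-truth points. The function name `merge_commission_fraction` is hypothetical, points are plain `(x, y)` tuples in pixels, and returning 0.0 when there are no predictions is an assumption rather than documented behaviour.

```python
import math


def merge_commission_fraction(pred_points, gt_points, max_distance):
    """Fraction of predictions lying within max_distance of >= 2 GT points."""
    if not pred_points:
        return 0.0  # assumption: no predictions means no merge errors
    merged = 0
    for p in pred_points:
        n_close = sum(1 for g in gt_points if math.dist(p, g) <= max_distance)
        if n_close >= 2:
            merged += 1
    return merged / len(pred_points)


gt_points = [(0, 0), (6, 0), (100, 100)]
pred_points = [(3, 0), (100, 101)]

# (3, 0) sits 3 px from two GT points, so it merges them; (100, 101) is near one only.
fraction = merge_commission_fraction(pred_points, gt_points, max_distance=8.96)
```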
- class milliontrees.common.metrics.all_metrics.MSE(name=None)[source]¶
Bases:
ElementwiseLoss
- class milliontrees.common.metrics.all_metrics.MaskAccuracy(iou_threshold=0.4, score_threshold=0.1, name=None, geometry_name='masks', metric='accuracy')[source]¶
Bases:
ElementwiseMetric
Per-image mask recall or accuracy with greedy 1:1 mask IoU matching.
- class milliontrees.common.metrics.all_metrics.MaskAwareDetectionPrecision(iou_threshold=0.4, score_threshold=0.1, tree_fraction_threshold=0.5, require_tree_coverage_mask=False, name=None, geometry_name='boxes', tree_coverage_key='tree_coverage_mask')[source]¶
Bases:
ElementwiseMetric
Precision metric that avoids penalizing predictions on unannotated tree regions.
Unmatched predictions are excluded from false positives when enough of their box area overlaps tree pixels from tree_coverage_mask.
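The exclusion rule can be sketched as follows. This is a simplified illustration under several assumptions: the helpers `tree_fraction` and `mask_aware_precision` are hypothetical, boxes use integer pixel coordinates `(xmin, ymin, xmax, ymax)`, the coverage mask is a 2D list of 0/1 values indexed `[row][col]`, and a box is excused when its tree fraction is at or above `tree_fraction_threshold`.

```python
def tree_fraction(box, tree_mask):
    """Fraction of the box area covered by tree pixels."""
    xmin, ymin, xmax, ymax = box
    area = (xmax - xmin) * (ymax - ymin)
    if area <= 0:
        return 0.0
    covered = sum(
        tree_mask[r][c] for r in range(ymin, ymax) for c in range(xmin, xmax)
    )
    return covered / area


def mask_aware_precision(n_matched, unmatched_boxes, tree_mask,
                         tree_fraction_threshold=0.5):
    """Precision where unmatched boxes on tree pixels are not false positives."""
    fp = sum(
        1 for b in unmatched_boxes
        if tree_fraction(b, tree_mask) < tree_fraction_threshold
    )
    denom = n_matched + fp
    return n_matched / denom if denom else 0.0  # assumption for the empty case


# 4x4 mask: the left two columns are tree pixels, the right two are not.
tree_mask = [[1, 1, 0, 0] for _ in range(4)]
unmatched = [(0, 0, 2, 2), (2, 0, 4, 2)]

precision = mask_aware_precision(1, unmatched, tree_mask)
```

Here one unmatched box lies entirely on tree pixels and is excused, so precision is 1 / (1 + 1) instead of the 1 / 3 that plain precision would give.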
- class milliontrees.common.metrics.all_metrics.MaskAwareKeypointPrecision(distance_threshold: float = 0.02, score_threshold: float = 0.1, tree_fraction_threshold: float = 0.5, require_tree_coverage_mask: bool = False, name: str | None = None, geometry_name: str = 'y', tree_coverage_key: str = 'tree_coverage_mask', image_size: int = 448)[source]¶
Bases:
ElementwiseMetric
Precision for point detection that ignores unmatched points on tree-covered pixels.
- class milliontrees.common.metrics.all_metrics.MaskAwareMaskPrecision(iou_threshold=0.4, score_threshold=0.1, tree_fraction_threshold=0.5, require_tree_coverage_mask=False, name=None, geometry_name='masks', tree_coverage_key='tree_coverage_mask')[source]¶
Bases:
ElementwiseMetric
Precision for mask detection that ignores unmatched masks on tree-covered pixels.
- class milliontrees.common.metrics.all_metrics.MergeCommissionMetric(iou_threshold: float = 0.4, score_threshold: float = 0.1, geometry_name: str = 'y', modality: str = 'bbox', name: str | None = None)[source]¶
Bases:
ElementwiseMetric
Fraction of predictions with IoU > iou_threshold against two or more GT objects.
- class milliontrees.common.metrics.all_metrics.MultiTaskAccuracy(prediction_fn=None, name=None)[source]¶
Bases:
MultiTaskMetric
- class milliontrees.common.metrics.all_metrics.MultiTaskAveragePrecision(prediction_fn=None, name=None, average='macro')[source]¶
Bases:
MultiTaskMetric
- class milliontrees.common.metrics.all_metrics.PrecisionAtRecall(threshold, score_fn=None, name=None)[source]¶
Bases:
Metric
Given a specific model threshold, determine the precision score achieved.
- class milliontrees.common.metrics.all_metrics.Recall(prediction_fn=None, name=None, average='binary')[source]¶
Bases:
Metric
- milliontrees.common.metrics.all_metrics.multiclass_logits_to_pred(logits)[source]¶
Converts multi-class logits into predictions.
This function takes a tensor of logits with shape (batch_size, …, n_classes) and computes predictions by applying argmax along the last dimension.
- Parameters:
logits (Tensor) – A tensor of shape (batch_size, …, n_classes) representing multi-class logits.
- Returns:
A tensor containing predicted class indices.
- Return type:
Tensor
- milliontrees.common.metrics.all_metrics.pseudolabel_binary_logits(logits, confidence_threshold)[source]¶
Applies a confidence threshold to binary logits and generates pseudo-labels.
- Parameters:
logits (Tensor) – A tensor of shape (batch_size, n_tasks) representing binary logits. A positive value (>0) indicates a positive prediction for the corresponding (example, task).
confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.
- Returns:
- unlabeled_y_pred (Tensor): A filtered version of logits, discarding rows (examples) where no predictions exceed the confidence threshold.
- unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of logits, where entries below the confidence threshold are set to NaN. Rows with no confident predictions are discarded.
- pseudolabels_kept_frac (float): The fraction of (example, task) pairs that are not set to NaN or discarded.
- mask (Tensor): A mask indicating which predictions meet the confidence threshold.
- Return type:
tuple
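The documented behaviour can be approximated with plain Python lists. This is a sketch, not the library's code: `pseudolabel_binary` is a hypothetical name, confidence is assumed to be max(p, 1 - p) with p the sigmoid of the logit, the comparison is assumed to be inclusive, and the returned mask is assumed to flag which rows (examples) were kept.

```python
import math


def pseudolabel_binary(logits, confidence_threshold):
    """logits: list of rows, each a list of per-task binary logits."""
    kept_pred, kept_pseudo, row_mask = [], [], []
    n_confident, n_total = 0, 0
    for row in logits:
        pseudo_row = []
        for z in row:
            p = 1.0 / (1.0 + math.exp(-z))           # sigmoid
            confident = max(p, 1.0 - p) >= confidence_threshold
            n_total += 1
            if confident:
                n_confident += 1
                pseudo_row.append(1.0 if z > 0 else 0.0)  # hard label
            else:
                pseudo_row.append(math.nan)               # unconfident entry
        keep = any(not math.isnan(v) for v in pseudo_row)
        row_mask.append(keep)
        if keep:
            kept_pred.append(row)
            kept_pseudo.append(pseudo_row)
    kept_frac = n_confident / n_total if n_total else 0.0
    return kept_pred, kept_pseudo, kept_frac, row_mask


logits = [[4.0, -0.1], [0.2, -0.3], [-5.0, 3.0]]
y_pred, y_pseudo, kept_frac, mask = pseudolabel_binary(logits, 0.9)
```

The middle row has no confident entry and is dropped; the first row keeps one hard label and one NaN.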
- milliontrees.common.metrics.all_metrics.pseudolabel_detection(preds, confidence_threshold)[source]¶
Filters detection predictions based on a confidence threshold.
- Parameters:
preds (List[dict]) – A list of length batch_size, where each entry is a dictionary containing the following keys:
- ‘boxes’ (Tensor): Bounding box coordinates.
- ‘labels’ (Tensor): Class labels for detected objects.
- ‘scores’ (Tensor): Confidence scores for each detection.
- ‘losses’ (dict): An empty dictionary (not used).
confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.
- Returns:
A filtered version of preds, where detections with confidence scores below confidence_threshold are removed.
- Return type:
List[dict]
- milliontrees.common.metrics.all_metrics.pseudolabel_detection_discard_empty(preds, confidence_threshold)[source]¶
Filters detection predictions based on a confidence threshold and discards empty entries.
- Parameters:
preds (List[dict]) – A list of length batch_size, where each entry is a dictionary containing the following keys:
- ‘boxes’ (Tensor): Bounding box coordinates.
- ‘labels’ (Tensor): Class labels for detected objects.
- ‘scores’ (Tensor): Confidence scores for each detection.
- ‘losses’ (dict): An empty dictionary (not used).
confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.
- Returns:
A filtered version of preds, where detections with confidence scores below confidence_threshold are removed. Entries with no remaining detections are discarded from the list.
- Return type:
List[dict]
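Both detection filters (with and without discarding empty entries) can be illustrated with one small helper. The name `filter_detections` and its `discard_empty` flag are hypothetical and exist only for this sketch; plain lists stand in for tensors, and keeping scores at or above the threshold is an assumption.

```python
def filter_detections(preds, confidence_threshold, discard_empty=False):
    """Drop low-confidence detections; optionally drop images left empty."""
    filtered = []
    for pred in preds:
        keep = [i for i, s in enumerate(pred["scores"])
                if s >= confidence_threshold]
        if discard_empty and not keep:
            continue  # nothing confident in this image
        filtered.append({
            "boxes": [pred["boxes"][i] for i in keep],
            "labels": [pred["labels"][i] for i in keep],
            "scores": [pred["scores"][i] for i in keep],
            "losses": {},
        })
    return filtered


preds = [
    {"boxes": [[0, 0, 5, 5], [1, 1, 4, 4]], "labels": [0, 0],
     "scores": [0.95, 0.30], "losses": {}},
    {"boxes": [[2, 2, 8, 8]], "labels": [0],
     "scores": [0.10], "losses": {}},
]

kept_all = filter_detections(preds, 0.5)
kept_nonempty = filter_detections(preds, 0.5, discard_empty=True)
```

The first call keeps both images, with the second one emptied; the second call removes the emptied image from the list.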
- milliontrees.common.metrics.all_metrics.pseudolabel_multiclass_logits(logits, confidence_threshold)[source]¶
Applies a confidence threshold to multi-class logits and generates pseudo-labels.
- Parameters:
logits (Tensor) – A tensor of shape (batch_size, …, n_classes) representing multi-class logits.
confidence_threshold (float) – A threshold in the range [0,1] used to filter predictions.
- Returns:
- unlabeled_y_pred (Tensor): A filtered version of logits, discarding rows (examples) where no predictions exceed the confidence threshold.
- unlabeled_y_pseudo (Tensor): A hard pseudo-labeled version of logits, where examples with confidence below the threshold are discarded.
- pseudolabels_kept_frac (float): The fraction of examples retained after filtering.
- mask (Tensor): A mask indicating which predictions meet the confidence threshold.
- Return type:
tuple
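A plain-Python approximation for logits of shape (batch_size, n_classes) is shown below. It is a sketch under assumptions: `pseudolabel_multiclass` and `softmax` are hypothetical helpers, confidence is taken to be the maximum softmax probability, and the comparison against the threshold is inclusive.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def pseudolabel_multiclass(logits, confidence_threshold):
    """Keep examples whose top softmax probability meets the threshold."""
    kept_logits, pseudo_labels, mask = [], [], []
    for row in logits:
        probs = softmax(row)
        label = max(range(len(row)), key=lambda k: probs[k])  # argmax
        confident = probs[label] >= confidence_threshold
        mask.append(confident)
        if confident:
            kept_logits.append(row)
            pseudo_labels.append(label)
    kept_frac = sum(mask) / len(mask) if mask else 0.0
    return kept_logits, pseudo_labels, kept_frac, mask


logits = [[5.0, 0.0, 0.0], [1.0, 0.9, 1.1], [0.0, 6.0, 0.0]]
y_pred, y_pseudo, kept_frac, mask = pseudolabel_multiclass(logits, 0.8)
```

The middle example has nearly uniform probabilities, so it falls below the threshold and is discarded.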
milliontrees.common.metrics.loss module¶
- class milliontrees.common.metrics.loss.ElementwiseLoss(loss_fn, name=None)[source]¶
Bases:
ElementwiseMetric
- class milliontrees.common.metrics.loss.MultiTaskLoss(loss_fn, name=None)[source]¶
Bases:
MultiTaskMetric
milliontrees.common.metrics.metric module¶
- class milliontrees.common.metrics.metric.ElementwiseMetric(name)[source]¶
Bases:
Metric
Averages.
- property agg_metric_field¶
The name of the key in the results dictionary returned by Metric.compute().
- compute_element_wise(y_pred, y_true, return_dict=True)[source]¶
Computes element-wise metric.
- Parameters:
y_pred (Tensor) – Predicted targets or model output
y_true (Tensor) – True targets
return_dict (bool) – Whether to return the output as a dictionary or a tensor
- Output (return_dict=False):
element_wise_metrics (Tensor): tensor of size (batch_size, )
- Output (return_dict=True):
results (dict): Dictionary of results, mapping metric.name to element_wise_metrics
- class milliontrees.common.metrics.metric.Metric(name)[source]¶
Bases:
object
Parent class for metrics.
- property agg_metric_field¶
The name of the key in the results dictionary returned by Metric.compute().
This should correspond to the aggregate metric computed on all of y_pred and y_true, in contrast to a group-wise evaluation.
- compute(y_pred, y_true, return_dict=True)[source]¶
Computes metric.
- This is a wrapper around _compute.
- Args:
y_pred (Tensor): Predicted targets or model output
y_true (Tensor): True targets
return_dict (bool): Whether to return the output as a dictionary or a tensor
- Output (return_dict=False):
metric (0-dim tensor): metric. If the inputs are empty, returns tensor(0.)
- Output (return_dict=True):
results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric
- compute_group_wise(y_pred, y_true, g, n_groups, return_dict=True)[source]¶
Computes metrics for each group.
- This is a wrapper around _compute.
- Args:
y_pred (Tensor): Predicted targets or model output
y_true (Tensor): True targets
g (Tensor): groups
n_groups (int): number of groups
return_dict (bool): Whether to return the output as a dictionary or a tensor
- Output (return_dict=False):
group_metrics (Tensor): tensor of size (n_groups, ) including the average metric for each group
group_counts (Tensor): tensor of size (n_groups, ) including the group count
worst_group_metric (0-dim tensor): worst-group metric
For empty inputs/groups, corresponding metrics are tensor(0.)
- Output (return_dict=True):
results (dict): Dictionary of results
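The three outputs listed for return_dict=False can be illustrated for an element-wise metric that is simply averaged within each group. This sketch is not the library's implementation: `group_wise_average` and its `higher_is_better` flag are hypothetical, plain lists replace tensors, and ignoring empty groups when picking the worst group is an assumption.

```python
def group_wise_average(element_metrics, groups, n_groups, higher_is_better=True):
    """Return (group_metrics, group_counts, worst_group_metric)."""
    sums = [0.0] * n_groups
    counts = [0] * n_groups
    for value, g in zip(element_metrics, groups):
        sums[g] += value
        counts[g] += 1
    # Empty groups report 0.0, mirroring the documented tensor(0.) convention.
    group_metrics = [
        sums[g] / counts[g] if counts[g] else 0.0 for g in range(n_groups)
    ]
    non_empty = [group_metrics[g] for g in range(n_groups) if counts[g] > 0]
    if not non_empty:
        worst = 0.0
    else:
        # Assumption: empty groups do not take part in the worst-group choice.
        worst = min(non_empty) if higher_is_better else max(non_empty)
    return group_metrics, counts, worst


element_metrics = [1.0, 0.0, 1.0, 0.5]
groups = [0, 0, 2, 2]
group_metrics, group_counts, worst = group_wise_average(element_metrics, groups, 3)
```

Group 1 has no samples, so it reports 0.0 with a count of 0, while the worst-group value comes from group 0.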
- group_count_field(group_idx)[source]¶
The name of the keys corresponding to each group’s count in the results dictionary returned by Metric.compute_group_wise().
- group_metric_field(group_idx)[source]¶
The name of the keys corresponding to individual group evaluations in the results dictionary returned by Metric.compute_group_wise().
- property name¶
Metric name.
Used to name the key in the results dictionaries returned by the metric.
- worst(metrics)[source]¶
Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
- Parameters:
metrics (list, numpy array, or Tensor) – Metrics
- Output:
worst_metric (0-dim tensor): Worst-case metric
- property worst_group_metric_field¶
The name of the keys corresponding to the worst-group metric in the results dictionary returned by Metric.compute_group_wise().