Training Models
MillionTrees is a dataset package designed to be general and flexible. We recommend using PyTorch Lightning to train models for maximum reproducibility. Below is a complete example using Faster R-CNN for bounding box detection. Users are welcome to use any framework or model.
Batch Format
All MillionTrees dataloaders return batches as (metadata, images, targets):
- metadata: Tensor[B, 2] — [filename_id, source_id] per image
- images: Tensor[B, 3, H, W] — batch of RGB images
- targets: list[dict] — one dict per image with keys:
| Task | `y` | `labels` | Description |
|---|---|---|---|
| TreeBoxes | | | |
| TreePoints | | | |
| TreePolygons | | | Binary masks per instance |
Data Setup
from milliontrees.common.data_loaders import get_train_loader, get_eval_loader
from milliontrees.datasets.TreeBoxes import TreeBoxesDataset
# Use mini=True for development; remove for full training
# download=True asks the dataset class to fetch the data
# (presumably skipped when it is already on disk -- confirm).
dataset = TreeBoxesDataset(download=True, mini=True)
# Subset used to fit the model.
train_dataset = dataset.get_subset("train")
# NOTE: the subset named "val" here is the dataset's "test" split.
val_dataset = dataset.get_subset("test")
# Both loaders use the "standard" loader type with a batch size of 2.
train_loader = get_train_loader("standard", train_dataset, batch_size=2)
val_loader = get_eval_loader("standard", val_dataset, batch_size=2)
Model Definition
import pytorch_lightning as pl
import torch
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
class TreeDetector(pl.LightningModule):
    """Faster R-CNN single-class tree detector for MillionTrees box batches.

    Loads torchvision's ``fasterrcnn_resnet50_fpn_v2`` with its default
    pretrained weights and replaces the box-predictor head with a fresh
    ``FastRCNNPredictor`` sized for ``num_classes``.

    Args:
        num_classes: Number of outputs of the replacement head. Label 0 is
            background in torchvision detection models, so the default of 2
            means background plus one "tree" class.
        lr: Learning rate handed to the SGD optimizer.
    """

    def __init__(self, num_classes=2, lr=5e-3):
        super().__init__()
        # Records num_classes and lr so they are readable via self.hparams.
        self.save_hyperparameters()
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        )
        # Only the box predictor is rebuilt; the rest of the network is left
        # exactly as loaded from the pretrained weights.
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    def training_step(self, batch, batch_idx):
        """Compute and log the summed detection loss for one batch.

        Args:
            batch: ``(metadata, images, targets)``; each target dict holds
                its box coordinates under the ``"y"`` key.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The sum of every loss term returned by the wrapped model.
        """
        _metadata, images, targets = batch
        formatted = []
        for target in targets:
            boxes = target["y"]
            # Every annotation is a tree, so each box gets label 1 (0 is
            # background). The labels are created on the boxes' own device so
            # the pair always matches, whatever container `images` is.
            labels = torch.ones(len(boxes), dtype=torch.int64, device=boxes.device)
            formatted.append({"boxes": boxes, "labels": labels})
        # With targets supplied, the wrapped model returns a dict of losses.
        loss_dict = self.model(images, formatted)
        loss = sum(loss_dict.values())
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return SGD with momentum 0.9 over all parameters at ``hparams.lr``."""
        return torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9)
# Build the detector with its defaults (num_classes=2, lr=5e-3).
model = TreeDetector()
# Train for 10 epochs; accelerator="auto" leaves the hardware choice to Lightning.
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# Only the training loader is passed to fit; no validation loader is given.
trainer.fit(model, train_loader)
Note: Torchvision detection models expect labels as integers starting from 1 (0 is background). The MillionTrees target dict uses "y" for coordinates and "labels" for class labels, so you need to reformat targets as shown above: "y" is passed as "boxes", and every box is given the label 1 (tree).
Evaluation
After training, use the MillionTrees evaluation API to compute metrics:
from milliontrees.common.eval import Evaluator
# The dataset object carries its own evaluator; reuse it here.
evaluator = dataset.eval

# Switch to eval mode before running inference.
model.eval()

all_predictions = []
all_targets = []

# Gradients are not needed for inference.
with torch.no_grad():
    for metadata, images, targets in val_loader:
        # Call the wrapped torchvision model directly (no targets passed).
        detections = model.model(images)
        for detection, target in zip(detections, targets):
            # Put the detector's boxes under the "y" key, the same key the
            # targets use for coordinates.
            all_predictions.append(
                {
                    "y": detection["boxes"],
                    "labels": detection["labels"],
                    "scores": detection["scores"],
                }
            )
            all_targets.append(target)

results = evaluator(all_predictions, all_targets)
print(results)