Training models
MillionTrees is a dataset package designed to be general and flexible. We recommend using PyTorch Lightning to train models for maximum reproducibility. As an example, consider a simple object detection model that predicts bounding boxes around the trees in an image. Users are, of course, welcome to use any other framework or model; this is simply an example to get started.
Data setup
from milliontrees.common.data_loaders import get_train_loader
from milliontrees.datasets.TreeBoxes import TreeBoxesDataset
# Fetch the TreeBoxes data (the first download can take a long time),
# then keep only the training split and wrap it in a standard loader.
dataset = TreeBoxesDataset(download=True)
train_dataset = dataset.get_subset("train")
train_loader = get_train_loader(
    "standard",
    train_dataset,
    batch_size=2,
)
Model definition
# Create a simple PyTorch Lightning object detection model
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
class ObjectDetectionModel(pl.LightningModule):
    """Faster R-CNN tree detector wrapped in a PyTorch Lightning module.

    A pretrained ``fasterrcnn_resnet50_fpn`` from torchvision is loaded and
    its box-predictor head is replaced with one that outputs
    ``num_classes`` classes (by default: background + tree).

    Args:
        num_classes: Number of output classes, *including* background.
        lr: Learning rate for the SGD optimizer.
        momentum: Momentum for the SGD optimizer.
    """

    def __init__(self, num_classes=2, lr=0.005, momentum=0.9):
        super().__init__()
        # Imported locally so the class is self-contained; the original code
        # referenced this name without importing it.
        from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

        # Record constructor arguments in self.hparams (and in checkpoints)
        # so runs are reproducible and reloadable.
        self.save_hyperparameters()

        # ``weights="DEFAULT"`` is the torchvision >= 0.13 replacement for
        # the deprecated ``pretrained=True`` flag.
        self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            weights="DEFAULT"
        )
        # Replace the pretrained box predictor with one sized for our labels.
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(
            in_features, num_classes
        )

    def forward(self, x, targets=None):
        """Run the wrapped torchvision detector on ``x``.

        ``targets`` is optional and forwarded unchanged. See the torchvision
        detection docs for the train-mode (loss dict) and eval-mode
        (predictions) return values.
        """
        return self.model(x, targets)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the summed detection loss for one batch."""
        # The loader yields (metadata, images, targets); metadata is unused.
        _metadata, images, targets = batch
        # NOTE(review): torchvision detection models expect targets as a list
        # of dicts with "boxes" and "labels" keys; confirm TreeBoxesDataset
        # produces that format.
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        # Give Lightning the batch size explicitly rather than letting it
        # guess from the nested (metadata, images, targets) batch.
        self.log("train_loss", loss, batch_size=len(images))
        return loss

    def configure_optimizers(self):
        """Create an SGD optimizer over the detector's trainable parameters."""
        # Some pretrained backbone layers may be frozen (requires_grad=False);
        # only hand the optimizer parameters it can actually update.
        params = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.SGD(
            params, lr=self.hparams.lr, momentum=self.hparams.momentum
        )
# Initialize the model and a Lightning trainer; accelerator="auto" lets
# Lightning pick whatever hardware is available.
model = ObjectDetectionModel()
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# Reuse the loader built during data setup instead of constructing an
# identical second one; the alias keeps the `train_dataloader` name available.
train_dataloader = train_loader
trainer.fit(model, train_dataloader)