Examples
Python Examples
Baseline Boxes Example
1import os
2import argparse
3import warnings
4from typing import List, Tuple
5
6import numpy as np
7import pandas as pd
8import torch
9
10from deepforest import main as df_main
11from deepforest.utilities import read_file, format_geometry
12import matplotlib.pyplot as plt
13import matplotlib.patches as patches
14
15from milliontrees import get_dataset
16from milliontrees.common.data_loaders import get_eval_loader
17
18
def format_deepforest_predictions(
    images: np.ndarray,
    metadata: torch.Tensor,
    targets: List[dict],
    model: "df_main.deepforest",
    dataset,
    batch_index: int,
) -> Tuple[List[dict], List[pd.DataFrame]]:
    """Run DeepForest on a batch and convert to MillionTrees format.

    Args:
        images: Batch of images (array or tensor) passed to
            ``model.predict_step``.
        metadata: Per-image metadata rows; element 0 of each row is looked
            up in ``dataset._filename_id_to_code`` to recover the basename.
        targets: Per-image ground-truth dicts; iterated in lockstep with the
            predictions but not read here.
        model: DeepForest model used for inference.
        dataset: MillionTrees dataset providing ``_filename_id_to_code`` and
            ``_data_dir``.
        batch_index: Batch index forwarded to ``predict_step``.

    Returns:
        A tuple of:
        - list of y_pred dicts (``y`` as xmin/ymin/xmax/ymax boxes,
          ``labels``, ``scores``) for MillionTrees evaluation
        - list of prediction DataFrames (for plotting), each with an
          ``image_path`` column and a ``root_dir`` attribute
    """
    # str() instead of the private ``Path._str`` slot: CPython fills that
    # slot lazily, so it may not exist yet. str() also accepts a plain str.
    root_dir = os.path.join(str(dataset._data_dir), "images")

    batch_y_pred: List[dict] = []
    formatted_predictions: List[pd.DataFrame] = []

    # Suppress warnings only while this batch is processed instead of
    # silencing them for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # as_tensor avoids the redundant copy torch.tensor() makes of an
        # input that is already a tensor.
        images_tensor = torch.as_tensor(images)
        predictions = model.predict_step(images_tensor, batch_index)

        for image_metadata, pred, _image_targets, _image in zip(
                metadata, predictions, targets, images_tensor):
            basename = dataset._filename_id_to_code[int(image_metadata[0])]

            if pred is None or len(pred["boxes"]) == 0:
                y_pred = {
                    "y": torch.zeros((0, 4), dtype=torch.float32),
                    "labels": torch.zeros((0,), dtype=torch.int64),
                    "scores": torch.zeros((0,), dtype=torch.float32),
                }
                formatted_pred = pd.DataFrame(
                    columns=["xmin", "ymin", "xmax", "ymax", "score", "label"])
            else:
                formatted_pred = format_geometry(pred)
                y_pred = {
                    "y": torch.tensor(
                        formatted_pred[["xmin", "ymin", "xmax",
                                        "ymax"]].values.astype("float32")),
                    "labels": torch.tensor(
                        formatted_pred.label.values.astype(np.int64)),
                    "scores": torch.tensor(
                        formatted_pred.score.values.astype("float32")),
                }

            # Same bookkeeping for empty and non-empty predictions.
            formatted_pred.root_dir = root_dir
            formatted_pred["image_path"] = basename

            batch_y_pred.append(y_pred)
            formatted_predictions.append(formatted_pred)

    return batch_y_pred, formatted_predictions
76
77
def _draw_boxes(ax, boxes: np.ndarray, color: str):
    """Outline every ``[xmin, ymin, xmax, ymax]`` box on *ax*.

    Entries that do not hold exactly four values are skipped. Negative
    widths or heights are clamped to zero, so inverted boxes draw nothing.
    """
    four_value_boxes = (b for b in boxes if len(b) == 4)
    for x0, y0, x1, y1 in four_value_boxes:
        left = float(x0)
        top = float(y0)
        ax.add_patch(
            patches.Rectangle(
                (left, top),
                max(0.0, float(x1) - left),
                max(0.0, float(y1) - top),
                linewidth=0.8,
                edgecolor=color,
                facecolor='none',
                alpha=0.8,
            ))
91
92
def save_gallery(thumbnails: List[dict], rows: int, cols: int, dpi: int,
                 output_path: str):
    """Save up to ``rows * cols`` thumbnails as a single grid image.

    Each thumbnail dict must contain ``"image"`` and may contain
    ``"gt_boxes"`` (drawn orange), ``"pred_boxes"`` (drawn royal blue) and
    ``"title"``. Unused grid cells are left blank. Nothing is written when
    there are no thumbnails or the grid has no cells.

    Args:
        thumbnails: Thumbnail dicts in display order.
        rows: Number of grid rows.
        cols: Number of grid columns.
        dpi: Figure resolution.
        output_path: Destination file; its parent directory is created if
            needed.
    """
    if len(thumbnails) == 0 or rows <= 0 or cols <= 0:
        return
    n_cells = rows * cols
    n = min(len(thumbnails), n_cells)

    # squeeze=False always returns a 2-D (rows, cols) array. np.atleast_2d
    # turned a single-column grid of shape (rows,) into (1, rows), so
    # axes[r, 0] raised IndexError for r >= 1.
    fig, axes = plt.subplots(rows,
                             cols,
                             figsize=(cols * 3, rows * 3),
                             dpi=dpi,
                             squeeze=False)
    try:
        for idx in range(n_cells):
            ax = axes[idx // cols, idx % cols]
            if idx < n:
                thumb = thumbnails[idx]
                ax.imshow(thumb["image"], interpolation='nearest')
                _draw_boxes(ax, thumb.get("gt_boxes", np.zeros((0, 4))),
                            color='orange')
                _draw_boxes(ax, thumb.get("pred_boxes", np.zeros((0, 4))),
                            color='royalblue')
                ax.set_title(thumb.get("title", ""), fontsize=8)
            ax.axis('off')

        fig.tight_layout(pad=0.2)
        out_dir = os.path.dirname(output_path)
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
122
123
def main():
    """Evaluate the pretrained DeepForest tree model on the TreeBoxes test split.

    Prints the dataset's evaluation summary. When ``--output-dir`` is given,
    the summary is also written to ``results_boxes.txt``. Unless
    ``--plot-interval`` is 0 (or negative), a gallery of the first
    ``gallery_rows * gallery_cols`` test images is saved to
    ``gallery_boxes.png``.
    """
    parser = argparse.ArgumentParser(
        description="Run baseline DeepForest evaluation on TreeBoxes.")
    parser.add_argument(
        "--root-dir",
        type=str,
        default=os.environ.get("MT_ROOT",
                               "/orange/ewhite/web/public/MillionTrees"),
        help="Dataset root directory",
    )
    parser.add_argument("--batch-size", type=int, default=12)
    parser.add_argument(
        "--plot-interval",
        type=int,
        default=250,
        help="Plot every Nth image; set 0 to disable plotting",
    )
    parser.add_argument("--gallery-rows", type=int, default=2)
    parser.add_argument("--gallery-cols", type=int, default=3)
    parser.add_argument("--gallery-dpi", type=int, default=72)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--max-batches", type=int, default=None)
    parser.add_argument("--mini", action="store_true", help="Use mini datasets for fast dev")
    parser.add_argument("--download", action="store_true", help="Download dataset if missing")
    parser.add_argument("--split-scheme",
                        type=str,
                        default="random",
                        choices=["random", "zeroshot", "crossgeometry"],
                        help="Dataset split scheme")
    args = parser.parse_args()

    # Load model
    model = df_main.deepforest()
    model.load_model("weecology/deepforest-tree")
    model.eval()

    # Load dataset
    box_dataset = get_dataset("TreeBoxes",
                              download=args.download,
                              mini=args.mini,
                              root_dir=args.root_dir,
                              split_scheme=args.split_scheme)
    test_subset = box_dataset.get_subset("test")
    test_loader = get_eval_loader("standard",
                                  test_subset,
                                  batch_size=args.batch_size)

    print(f"There are {len(test_loader)} batches in the test loader")

    all_y_pred: List[dict] = []
    all_y_true: List[dict] = []
    thumbnails: List[dict] = []

    # The gallery needs an output directory. "--plot-interval 0" is
    # documented as disabling plotting, so honor it here too.
    make_gallery = bool(args.output_dir) and args.plot_interval > 0
    max_thumbnails = (args.gallery_rows * args.gallery_cols
                      if make_gallery else 0)

    batch_index = 0
    for batch in test_loader:
        metadata, images, targets = batch
        mt_preds, _ = format_deepforest_predictions(images, metadata,
                                                    targets, model,
                                                    box_dataset,
                                                    batch_index)

        for image_metadata, y_pred, image_targets, image in zip(
                metadata, mt_preds, targets, images):

            # Collect thumbnails for gallery
            if len(thumbnails) < max_thumbnails:
                # Resolve the name from metadata, as
                # format_deepforest_predictions does, so images without
                # predictions still get a title.
                basename = box_dataset._filename_id_to_code[int(
                    image_metadata[0])]
                pred_boxes = y_pred.get("y", torch.zeros((0, 4)))
                image_np = (image.permute(1, 2, 0).numpy() * 255).clip(
                    0, 255).astype("uint8")
                recall = box_dataset.metrics["recall"]._recall(
                    image_targets["y"], pred_boxes, iou_threshold=0.4)
                thumbnails.append({
                    "image": image_np,
                    "pred_boxes": pred_boxes.detach().cpu().numpy(),
                    "gt_boxes": image_targets["y"].detach().cpu().numpy(),
                    "title": f"{basename} R@0.4={float(recall):.2f}",
                })

            all_y_pred.append(y_pred)
            all_y_true.append(image_targets)
        batch_index += 1

        if args.max_batches is not None and batch_index >= args.max_batches:
            break

    # Only the metadata rows that were actually evaluated (--max-batches).
    results, results_str = box_dataset.eval(
        all_y_pred, all_y_true, test_subset.metadata_array[:len(all_y_true)])
    print(results_str)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "results_boxes.txt"), "w",
                  encoding="utf-8") as f:
            f.write(results_str)
        # Save a compact gallery image
        gallery_path = os.path.join(args.output_dir, "gallery_boxes.png")
        save_gallery(thumbnails, args.gallery_rows, args.gallery_cols,
                     args.gallery_dpi, gallery_path)


if __name__ == "__main__":
    main()
233
234
Baseline Points Example
1import os
2import argparse
3import warnings
4from typing import List, Tuple
5
6import numpy as np
7import pandas as pd
8import torch
9
10from deepforest import main as df_main
11from deepforest.utilities import read_file, format_geometry
12from deepforest.visualize import plot_results
13import geopandas as gpd
14
15from milliontrees import get_dataset
16from milliontrees.common.data_loaders import get_eval_loader
17
18
def _map_labels_to_int(values: pd.Series, model: "df_main.deepforest") -> np.ndarray:
    """Return *values* as an int64 label array.

    Object-dtype and pandas string-dtype values are treated as label names
    and mapped through ``model.label_dict``; names missing from it map to 0.
    Any other dtype (numeric, bool, categorical, ...) is cast directly to
    int64.

    Args:
        values: Label column from a prediction DataFrame.
        model: Object exposing a ``label_dict`` name -> int mapping.

    Returns:
        1-D ``np.int64`` array with one entry per row of *values*.
    """
    is_name_like = (pd.api.types.is_object_dtype(values.dtype)
                    or pd.api.types.is_string_dtype(values.dtype))
    if is_name_like:
        # Covers the pandas "string" dtype, whose values previously reached
        # astype(int64) unmapped and crashed on names like "Tree".
        label_dict = model.label_dict
        return np.array([label_dict.get(x, 0) for x in values], dtype=np.int64)
    return values.values.astype(np.int64)
23
24
def format_deepforest_predictions(
    images: np.ndarray,
    metadata: torch.Tensor,
    targets: List[dict],
    model: "df_main.deepforest",
    dataset,
    batch_index: int,
) -> Tuple[List[dict], List[pd.DataFrame]]:
    """Run DeepForest on a batch and convert to MillionTrees point format.

    DeepForest predicts boxes; each box is reduced to its centroid so the
    predictions can be compared with point annotations.

    Args:
        images: Batch of images (array or tensor) passed to
            ``model.predict_step``.
        metadata: Per-image metadata rows; element 0 of each row is looked
            up in ``dataset._filename_id_to_code`` to recover the basename.
        targets: Per-image ground-truth dicts; iterated in lockstep with the
            predictions but not read here.
        model: DeepForest model used for inference (its ``label_dict`` maps
            string labels to ints).
        dataset: MillionTrees dataset providing ``_filename_id_to_code`` and
            ``_data_dir``.
        batch_index: Batch index forwarded to ``predict_step``.

    Returns:
        A tuple of:
        - list of y_pred dicts (``y`` as an (N, 2) array of centroid x/y,
          ``labels``, ``scores``) for MillionTrees evaluation
        - list of prediction DataFrames (for plotting). Empty DataFrames
          also carry an ``image_path`` column and a ``root_dir`` attribute.
    """
    # str() instead of the private ``Path._str`` slot: CPython fills that
    # slot lazily, so it may not exist yet.
    root_dir = os.path.join(str(dataset._data_dir), "images")

    batch_y_pred: List[dict] = []
    formatted_predictions: List[pd.DataFrame] = []

    # Suppress warnings only while this batch is processed instead of
    # silencing them for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        images_tensor = torch.as_tensor(images)
        predictions = model.predict_step(images_tensor, batch_index)

        for image_metadata, image_pred, _image_targets, _image in zip(
                metadata, predictions, targets, images_tensor):
            basename = dataset._filename_id_to_code[int(image_metadata[0])]

            # Guard against None as well as an empty prediction, matching
            # the box and polygon examples.
            if image_pred is None or len(image_pred["boxes"]) == 0:
                y_pred = {
                    "y": torch.zeros((0, 2), dtype=torch.float32),
                    "labels": torch.zeros((0,), dtype=torch.int64),
                    "scores": torch.zeros((0,), dtype=torch.float32),
                }
                formatted_pred = pd.DataFrame(columns=["x", "y", "score", "label"])
                formatted_pred["image_path"] = basename
                formatted_pred.root_dir = root_dir
            else:
                formatted_pred = format_geometry(image_pred)
                formatted_pred["image_path"] = basename
                formatted_pred = read_file(
                    formatted_pred,
                    root_dir=root_dir,
                    image_path=basename,
                )

                # Convert boxes to centroids (vectorized x/y extraction).
                centroids = gpd.GeoSeries(formatted_pred["geometry"]).centroid
                formatted_pred["geometry"] = centroids
                formatted_pred["x"] = centroids.x.to_numpy()
                formatted_pred["y"] = centroids.y.to_numpy()
                y_pred = {
                    "y": torch.tensor(formatted_pred[["x", "y"]].values.astype("float32")),
                    "labels": torch.tensor(_map_labels_to_int(formatted_pred.label, model)),
                    "scores": torch.tensor(formatted_pred.score.values.astype("float32")),
                }

            batch_y_pred.append(y_pred)
            formatted_predictions.append(formatted_pred)

    return batch_y_pred, formatted_predictions
75
76
def plot_eval_result(
    y_pred: dict,
    pred_df: pd.DataFrame,
    image_targets: dict,
    image_tensor: torch.Tensor,
    dataset,
    batch_index: int,
    output_dir: str = None,
):
    """Plot predicted points against ground-truth points for one image.

    Args:
        y_pred: MillionTrees-format prediction dict (not read here).
        pred_df: Prediction DataFrame from ``format_deepforest_predictions``.
            If it is empty, the basename used for plotting is "empty".
        image_targets: Ground-truth dict whose ``"y"`` tensor holds point
            x/y coordinates.
        image_tensor: Channel-first image tensor; it is permuted to
            channel-last and multiplied by 255 for display (presumably values
            are in [0, 1] — confirm against the dataset).
        dataset: MillionTrees dataset providing ``_data_dir``.
        batch_index: Zero-padded prefix for the saved filename.
        output_dir: If given, the figure is saved there as
            ``{batch_index:06d}_{basename}.png``.
    """
    # str() instead of the private, lazily-populated ``Path._str`` slot.
    root_dir = os.path.join(str(dataset._data_dir), "images")

    has_preds = isinstance(pred_df, pd.DataFrame) and len(pred_df) > 0
    basename = pred_df.image_path.unique()[0] if has_preds else "empty"

    # Ground truth
    gt_df = pd.DataFrame(image_targets["y"].numpy(), columns=["x", "y"])
    gt_df["image_path"] = basename
    gt_df = read_file(
        gt_df,
        root_dir=root_dir,
        image_path=basename,
        label="Tree",
    )
    gt_df["label"] = "Tree"
    gt_df["score"] = 1

    # Predictions
    if has_preds:
        pred_vis_df = read_file(
            pred_df,
            root_dir=root_dir,
            image_path=basename,
            label="Tree",
        )
        centroids = gpd.GeoSeries(pred_vis_df["geometry"]).centroid
        pred_vis_df["geometry"] = centroids
        pred_vis_df["x"] = centroids.x.to_numpy()
        pred_vis_df["y"] = centroids.y.to_numpy()
        if "label" not in pred_vis_df.columns:
            pred_vis_df["label"] = "Tree"
    else:
        pred_vis_df = pred_df

    # Image channel-last, 0-255
    image = image_tensor.permute(1, 2, 0).numpy() * 255

    pred_vis_df.root_dir = root_dir
    fig = plot_results(pred_vis_df, gt_df, image=image.astype("int32"))
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, f"{batch_index:06d}_{basename}.png")
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
128
def main():
    """Evaluate the pretrained DeepForest tree model on the TreePoints test split.

    Predicted boxes are reduced to centroids by format_deepforest_predictions
    before scoring. Prints the evaluation summary and, when ``--output-dir``
    is given, writes it to ``results_points.txt``. With ``--plot-interval N``
    (N > 0), every Nth test image is plotted via plot_eval_result.
    """
    parser = argparse.ArgumentParser(description="Run baseline DeepForest on TreePoints using centroid conversion.")
    parser.add_argument(
        "--root-dir",
        type=str,
        default=os.environ.get("MT_ROOT", "/orange/ewhite/web/public/MillionTrees/"),
        help="Dataset root directory",
    )
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--plot-interval", type=int, default=1000, help="Plot every Nth image; set 0 to disable plotting")
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--max-batches", type=int, default=None)
    parser.add_argument("--mini", action="store_true", help="Use mini datasets for fast dev")
    parser.add_argument("--download", action="store_true", help="Download dataset if missing")
    parser.add_argument("--split-scheme",
                        type=str,
                        default="random",
                        choices=["random", "zeroshot", "crossgeometry"],
                        help="Dataset split scheme")
    args = parser.parse_args()

    # Load model
    model = df_main.deepforest()
    model.load_model("weecology/deepforest-tree")
    model.eval()

    # Load dataset
    point_dataset = get_dataset("TreePoints",
                                root_dir=args.root_dir,
                                mini=args.mini,
                                download=args.download,
                                split_scheme=args.split_scheme)
    test_subset = point_dataset.get_subset("test")
    test_loader = get_eval_loader("standard", test_subset, batch_size=args.batch_size)

    print(f"There are {len(test_loader)} batches in the test loader")

    all_y_pred: List[dict] = []
    all_y_true: List[dict] = []

    batch_index = 0
    image_index = 0  # running image count across batches
    for batch in test_loader:
        metadata, images, targets = batch
        mt_preds, df_preds = format_deepforest_predictions(images, metadata, targets, model, point_dataset, batch_index)

        for _image_metadata, y_pred, pred, image_targets, image in zip(metadata, mt_preds, df_preds, targets, images):
            # Select by image count so "--plot-interval N" plots every Nth
            # image, as its help text says. The old batch-based check plotted
            # every image of every Nth batch. Passing image_index also gives
            # each saved plot a unique filename prefix.
            if args.plot_interval > 0 and image_index % args.plot_interval == 0:
                plot_eval_result(y_pred, pred, image_targets, image, point_dataset, image_index, args.output_dir)

            all_y_pred.append(y_pred)
            all_y_true.append(image_targets)
            image_index += 1
        batch_index += 1

        if args.max_batches is not None and batch_index >= args.max_batches:
            break

    # Evaluate using dataset's metric implementation; only the metadata rows
    # that were actually evaluated (--max-batches) are passed.
    results, results_str = point_dataset.eval(all_y_pred, all_y_true, test_subset.metadata_array[:len(all_y_true)])
    print(results_str)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "results_points.txt"), "w", encoding="utf-8") as f:
            f.write(results_str)


if __name__ == "__main__":
    main()
197
198
Baseline Polygons Example
1import os
2import argparse
3import warnings
4from typing import List, Tuple
5
6import numpy as np
7import pandas as pd
8import torch
9
10from deepforest import main as df_main
11from deepforest.utilities import read_file, format_geometry
12from deepforest.visualize import plot_results
13
14from milliontrees import get_dataset
15from milliontrees.common.data_loaders import get_eval_loader
16
17
def format_deepforest_predictions(
    images: np.ndarray,
    metadata: torch.Tensor,
    targets: List[dict],
    model,  # DeepForest model; left unannotated on purpose
    dataset,
    batch_index: int,
) -> Tuple[List[dict], List[pd.DataFrame]]:
    """
    Run DeepForest on a batch and convert to MillionTrees format.

    Args:
        images: Input images (array or tensor) passed to ``model.predict_step``
        metadata: Per-image metadata rows; element 0 of each row is looked up
            in ``dataset._filename_id_to_code`` to recover the basename
        targets: Per-image target dicts; iterated in lockstep with the
            predictions but not read here
        model: DeepForest model instance
        dataset: Dataset instance providing ``_filename_id_to_code`` and
            ``_data_dir``
        batch_index: Index of the current batch, forwarded to ``predict_step``

    Returns:
        Tuple of (y_pred dicts with ``y`` as xmin/ymin/xmax/ymax boxes plus
        ``labels`` and ``scores``; prediction DataFrames, each with an
        ``image_path`` column and a ``root_dir`` attribute)
    """
    # str() instead of the private ``Path._str`` slot: CPython fills that
    # slot lazily, so it may not exist yet.
    root_dir = os.path.join(str(dataset._data_dir), "images")

    batch_y_pred: List[dict] = []
    formatted_predictions: List[pd.DataFrame] = []

    # Suppress warnings only while this batch is processed instead of
    # silencing them for the rest of the process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        images_tensor = torch.as_tensor(images)
        predictions = model.predict_step(images_tensor, batch_index)

        for image_metadata, pred, _image_targets, _image in zip(
            metadata, predictions, targets, images_tensor
        ):
            basename = dataset._filename_id_to_code[int(image_metadata[0])]

            if pred is None or len(pred["boxes"]) == 0:
                y_pred = {
                    "y": torch.zeros((0, 4), dtype=torch.float32),
                    "labels": torch.zeros((0,), dtype=torch.int64),
                    "scores": torch.zeros((0,), dtype=torch.float32),
                }
                formatted_pred = pd.DataFrame(
                    columns=["xmin", "ymin", "xmax", "ymax", "score", "label"]
                )
            else:
                formatted_pred = format_geometry(pred)
                y_pred = {
                    "y": torch.tensor(
                        formatted_pred[["xmin", "ymin", "xmax", "ymax"]].values.astype("float32")
                    ),
                    "labels": torch.tensor(formatted_pred.label.values.astype(np.int64)),
                    "scores": torch.tensor(formatted_pred.score.values.astype("float32")),
                }

            # Same bookkeeping for empty and non-empty predictions.
            formatted_pred.root_dir = root_dir
            formatted_pred["image_path"] = basename

            batch_y_pred.append(y_pred)
            formatted_predictions.append(formatted_pred)

    return batch_y_pred, formatted_predictions
81
82
def plot_eval_result(
    y_pred: dict,
    pred_df: pd.DataFrame,
    image_targets: dict,
    image_tensor: torch.Tensor,
    dataset,
    batch_index: int,
    output_dir: str = None,
):
    """
    Plot predicted boxes against ground-truth boxes for one image and log recall.

    Args:
        y_pred: MillionTrees-format prediction dict; predicted boxes are read
            from ``"bboxes"`` if present, otherwise from ``"y"``
        pred_df: Prediction DataFrame from ``format_deepforest_predictions``;
            if empty, the basename used for plotting is "empty"
        image_targets: Ground-truth dict with an ``"bboxes"`` entry of
            xmin/ymin/xmax/ymax rows
        image_tensor: Channel-first image tensor; permuted to channel-last and
            multiplied by 255 for display (presumably values in [0, 1] — confirm)
        dataset: Dataset providing ``_data_dir`` and ``metrics["recall"]``
        batch_index: Zero-padded filename prefix and log index
        output_dir: If given, the figure is saved there as
            ``{batch_index:06d}_{basename}.png``
    """
    # str() instead of the private, lazily-populated ``Path._str`` slot.
    root_dir = os.path.join(str(dataset._data_dir), "images")
    basename = pred_df.image_path.unique()[0] if len(pred_df) > 0 else "empty"

    # Ground truth; as_tensor(...).numpy() accepts either a tensor or an array.
    gt_boxes = torch.as_tensor(image_targets["bboxes"]).detach().cpu().numpy()
    gt_df = pd.DataFrame(gt_boxes, columns=["xmin", "ymin", "xmax", "ymax"])
    gt_df["image_path"] = basename
    gt_df = read_file(
        gt_df,
        root_dir=root_dir,
        image_path=basename,
        label="Tree",
    )
    gt_df["label"] = "Tree"

    # Predictions
    if len(pred_df) > 0:
        pred_vis_df = read_file(pred_df, root_dir=root_dir, image_path=basename)
        if "label" not in pred_vis_df.columns:
            pred_vis_df["label"] = "Tree"
    else:
        pred_vis_df = pred_df

    # Image channel-last, 0-255
    image = image_tensor.permute(1, 2, 0).numpy() * 255

    # format_deepforest_predictions stores predicted boxes under "y". Looking
    # only for "bboxes" always fell back to an empty tensor, so the logged
    # recall was always 0.
    pred_boxes = y_pred.get("bboxes", y_pred.get("y", torch.zeros((0, 4))))
    recall = dataset.metrics["recall"]._recall(
        image_targets["bboxes"],
        pred_boxes,
        iou_threshold=0.4
    )

    # Plotting is best-effort, but report failures instead of hiding them.
    try:
        fig = plot_results(pred_vis_df, gt_df, image=image.astype("int32"))
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            out_path = os.path.join(output_dir, f"{batch_index:06d}_{basename}.png")
            fig.savefig(out_path, dpi=150, bbox_inches="tight")
    except Exception as err:
        print(f"Could not plot {basename} (idx {batch_index}): {err!r}")

    print(f"Image: {basename}, idx {batch_index}, Recall@0.4: {float(recall):.2f}")
138
139
def main():
    """Evaluate the pretrained DeepForest tree model on the TreePolygons test split.

    Prints the evaluation summary and, when ``--output-dir`` is given,
    writes it to ``results_polygons.txt``. With ``--plot-interval N``
    (N > 0), every Nth test image is plotted via plot_eval_result.
    """
    parser = argparse.ArgumentParser(
        description="Run baseline DeepForest evaluation on TreePolygons."
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        default=os.environ.get("MT_ROOT", "/orange/ewhite/web/public/MillionTrees/"),
        help="Dataset root directory",
    )
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument(
        "--plot-interval",
        type=int,
        default=250,
        help="Plot every Nth image; set 0 to disable plotting",
    )
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--max-batches", type=int, default=None)
    parser.add_argument(
        "--mini", action="store_true", help="Use mini datasets for fast dev"
    )
    parser.add_argument(
        "--download", action="store_true", help="Download dataset if missing"
    )
    parser.add_argument(
        "--split-scheme",
        type=str,
        default="random",
        choices=["random", "zeroshot", "crossgeometry"],
        help="Dataset split scheme",
    )
    args = parser.parse_args()

    # Load model
    model = df_main.deepforest()
    model.load_model("weecology/deepforest-tree")
    model.eval()

    # Load dataset
    polygon_dataset = get_dataset(
        "TreePolygons",
        root_dir=args.root_dir,
        mini=args.mini,
        download=args.download,
        split_scheme=args.split_scheme,
        image_size=224,
    )
    test_subset = polygon_dataset.get_subset("test")
    test_loader = get_eval_loader(
        "standard", test_subset, batch_size=args.batch_size
    )

    print(f"There are {len(test_loader)} batches in the test loader")

    all_y_pred: List[dict] = []
    all_y_true: List[dict] = []

    batch_index = 0
    image_index = 0  # running image count across batches
    for batch in test_loader:
        metadata, images, targets = batch
        mt_preds, df_preds = format_deepforest_predictions(
            images, metadata, targets, model, polygon_dataset, batch_index
        )

        for _image_metadata, y_pred, pred, image_targets, image in zip(
            metadata, mt_preds, df_preds, targets, images
        ):
            # Select by image count so "--plot-interval N" plots every Nth
            # image, as its help text says. The old batch-based check plotted
            # every image of every Nth batch. Passing image_index also gives
            # each saved plot a unique filename prefix.
            if args.plot_interval > 0 and image_index % args.plot_interval == 0:
                plot_eval_result(
                    y_pred, pred, image_targets, image,
                    polygon_dataset, image_index, args.output_dir
                )

            all_y_pred.append(y_pred)
            all_y_true.append(image_targets)
            image_index += 1
        batch_index += 1

        if args.max_batches is not None and batch_index >= args.max_batches:
            break

    # Only the metadata rows that were actually evaluated (--max-batches).
    results, results_str = polygon_dataset.eval(
        all_y_pred,
        all_y_true,
        test_subset.metadata_array[:len(all_y_true)],
    )
    print(results_str)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        output_path = os.path.join(args.output_dir, "results_polygons.txt")
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(results_str)


if __name__ == "__main__":
    main()
View Examples on GitHub
All example files are also available on GitHub: