milliontrees.common package

Subpackages

Submodules

milliontrees.common.data_loaders module

class milliontrees.common.data_loaders.GroupSampler(group_ids, batch_size, n_groups_per_batch, uniform_over_groups, distinct_groups)[source]

Bases: object

Constructs batches by first sampling groups, then sampling data from those groups.

It drops the last batch if it’s incomplete.
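The two-stage sampling described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the function name `group_batches`, the even split of `batch_size` across groups, and sampling with replacement inside a group are all assumptions.

```python
import random


def group_batches(group_ids, batch_size, n_groups_per_batch, seed=0):
    """Yield index batches: pick groups first, then examples from each picked group.

    Hypothetical helper for illustration; it mirrors the idea behind
    GroupSampler, not its actual code.
    """
    if batch_size % n_groups_per_batch != 0:
        raise ValueError("batch_size must be divisible by n_groups_per_batch")
    rng = random.Random(seed)
    per_group = batch_size // n_groups_per_batch

    # Map each group id to the dataset indices that belong to it.
    members = {}
    for idx, g in enumerate(group_ids):
        members.setdefault(g, []).append(idx)
    groups = sorted(members)

    # Only full batches are produced, so an incomplete last batch is dropped.
    n_batches = len(group_ids) // batch_size
    for _ in range(n_batches):
        # Stage 1: choose distinct groups uniformly at random.
        chosen = rng.sample(groups, n_groups_per_batch)
        batch = []
        for g in chosen:
            # Stage 2: sample with replacement (an assumption) so that small
            # groups can still fill their share of the batch.
            batch.extend(rng.choices(members[g], k=per_group))
        yield batch
```

With ten examples and `batch_size=4`, the sketch yields two full batches and discards the remaining two examples.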

milliontrees.common.data_loaders.get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs)[source]

Constructs and returns the data loader for evaluation.

Parameters:
  • loader – Loader type. Use ‘standard’ for standard loaders.

  • dataset – Dataset to load from.

  • batch_size – Batch size.

  • loader_kwargs – Keyword arguments passed to the torch DataLoader constructor.

Output:
  • data loader (DataLoader): Data loader.

milliontrees.common.data_loaders.get_train_loader(loader, dataset, batch_size, uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs)[source]

Constructs and returns the data loader for training.

Parameters:
  • loader – Loader type. Use ‘standard’ for standard loaders and ‘group’ for group loaders, which first sample groups and then sample a fixed number of examples belonging to each group.

  • dataset – Dataset to load from.

  • batch_size – Batch size.

  • uniform_over_groups – Whether to sample the groups uniformly or according to the natural data distribution. Setting this to None applies the default for each loader type: False for standard loaders and True for group loaders.

  • grouper – Grouper used for group loaders or when uniform_over_groups=True.

  • distinct_groups – Whether to sample distinct groups within each minibatch for group loaders.

  • n_groups_per_batch – Number of groups to sample in each minibatch for group loaders.

  • loader_kwargs – Keyword arguments passed to the torch DataLoader constructor.

Output:
  • data loader (DataLoader): Data loader.
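To make the effect of uniform_over_groups concrete, the sketch below computes per-example sampling weights that give every group the same total probability. It assumes uniform group sampling amounts to weighting each example by the inverse of its group's size; the helper name `uniform_group_weights` is hypothetical and not part of the package.

```python
from collections import Counter


def uniform_group_weights(group_ids):
    """Return one sampling weight per example so each group is equally likely.

    Assumption: "uniform over groups" means every group receives the same
    total probability mass, split evenly among its members.
    """
    counts = Counter(group_ids)
    n_groups = len(counts)
    # Each group gets mass 1 / n_groups, divided among its members.
    return [1.0 / (n_groups * counts[g]) for g in group_ids]
```

Weights of this kind could be handed to a weighted sampler, whereas sampling from the natural data distribution corresponds to equal weights for all examples.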

milliontrees.common.grouper module

class milliontrees.common.grouper.CombinatorialGrouper(dataset, groupby_fields)[source]

Bases: Grouper

group_field_str(group)[source]
Parameters:
  • group – A single integer representing a group.

Output:
  • group_str (str): A string containing the name of that group.

group_str(group)[source]
Parameters:
  • group – A single integer representing a group.

Output:
  • group_str (str): A string containing the pretty name of that group.

metadata_to_group(metadata, return_counts=False)[source]
Parameters:
  • metadata – An n x d matrix containing d metadata fields for n different points.

  • return_counts – If True, return group counts as well.

Output:
  • group (Tensor): An n-length vector of groups.

  • group_counts (Tensor): Returned only if return_counts is True. An n_groups-length vector of integers containing the number of data points in each group in the metadata.
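One way to picture how several metadata fields can be combined into a single integer group is a mixed-radix encoding, sketched below. This is an assumption made for illustration: the documentation does not state how CombinatorialGrouper assigns group indices, and `metadata_to_group_sketch` and `cardinalities` are hypothetical names.

```python
def metadata_to_group_sketch(metadata_rows, cardinalities, return_counts=False):
    """Map rows of per-field integer codes to single group indices.

    Assumed encoding (not confirmed by the source): field i contributes
    value * (product of the cardinalities of all earlier fields).
    """
    n_groups = 1
    for cardinality in cardinalities:
        n_groups *= cardinality

    groups = []
    for row in metadata_rows:
        index, factor = 0, 1
        for value, cardinality in zip(row, cardinalities):
            if not 0 <= value < cardinality:
                raise ValueError(f"field value {value} outside [0, {cardinality})")
            index += value * factor
            factor *= cardinality
        groups.append(index)

    if not return_counts:
        return groups

    # One count per possible group, including groups with no data points.
    counts = [0] * n_groups
    for g in groups:
        counts[g] += 1
    return groups, counts
```

Under this encoding, two fields with 2 and 3 possible values produce 6 groups, and the count vector always has one entry per possible combination.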

class milliontrees.common.grouper.Grouper[source]

Bases: object

Groupers group data points together based on their metadata.

They are used for training and evaluation, e.g., to measure the accuracies of different groups of data.

group_field_str(group)[source]
Parameters:
  • group – A single integer representing a group.

Output:
  • group_str (str): A string containing the name of that group.

group_str(group)[source]
Parameters:
  • group – A single integer representing a group.

Output:
  • group_str (str): A string containing the pretty name of that group.

metadata_to_group(metadata, return_counts=False)[source]
Parameters:
  • metadata – An n x d matrix containing d metadata fields for n different points.

  • return_counts – If True, return group counts as well.

Output:
  • group (Tensor): An n-length vector of groups.

  • group_counts (Tensor): Returned only if return_counts is True. An n_groups-length vector of integers containing the number of data points in each group in the metadata.

property n_groups

The number of groups defined by this Grouper.

milliontrees.common.utils module

milliontrees.common.utils.avg_over_groups(v, g, n_groups)[source]
Parameters:
  • v (Tensor) – Vector containing the quantity to average over.

  • g (Tensor) – Vector of the same length as v, containing group information.

Returns:
  • group_avgs (Tensor): Vector of length n_groups.

  • group_counts (Tensor): The number of elements in each group.

Return type:

tuple
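The computation can be sketched with plain lists instead of tensors. The name `avg_over_groups_sketch` is hypothetical, and reporting an average of 0.0 for a group with no members is an assumption rather than documented behavior.

```python
def avg_over_groups_sketch(v, g, n_groups):
    """Average the values in v within each group listed in g.

    Returns (group_avgs, group_counts), both of length n_groups.
    Assumption: empty groups get an average of 0.0.
    """
    if len(v) != len(g):
        raise ValueError("v and g must have the same length")
    sums = [0.0] * n_groups
    counts = [0] * n_groups
    for value, group in zip(v, g):
        sums[group] += value
        counts[group] += 1
    # Guard against division by zero for groups that never appear in g.
    avgs = [s / c if c else 0.0 for s, c in zip(sums, counts)]
    return avgs, counts
```

For example, with v = [1.0, 3.0, 10.0], g = [0, 0, 2], and n_groups = 3, group 1 is empty and keeps a count of zero.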

milliontrees.common.utils.format_eval_results(results: Dict[str, Any], dataset) → str[source]

Format evaluation results into well-formatted tables.

Parameters:
  • results – Dictionary containing evaluation results

  • dataset – Dataset object with source mapping information

Returns:

Formatted string with tables

milliontrees.common.utils.get_counts(g, n_groups)[source]

This differs from split_into_groups in how it handles missing groups.

get_counts always returns a count array of length n_groups, whereas split_into_groups returns a unique_counts array whose length is the number of unique groups present in g.

Parameters:
  • g – Vector of groups.

Returns:
  • counts (ndarray): An array of length n_groups, denoting the count of each group.
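The fixed-length behavior can be illustrated with a short list-based sketch. The name `get_counts_sketch` is hypothetical, and the real function returns an array rather than a list.

```python
def get_counts_sketch(g, n_groups):
    """Count how often each group id in range(n_groups) appears in g.

    Groups that never appear still get an entry, with a count of zero.
    """
    counts = [0] * n_groups
    for group in g:
        counts[group] += 1
    return counts
```

Calling it with g = [0, 2, 2] and n_groups = 4 keeps zero-count entries for groups 1 and 3, which is the point of difference from split_into_groups.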

milliontrees.common.utils.map_to_id_array(df, ordered_map={})[source]
milliontrees.common.utils.maximum(numbers, empty_val=0.0)[source]
milliontrees.common.utils.minimum(numbers, empty_val=0.0)[source]
milliontrees.common.utils.numel(obj)[source]
milliontrees.common.utils.shuffle_arr(arr, seed=None)[source]
milliontrees.common.utils.split_into_groups(g)[source]

Splits the input tensor into unique groups and their corresponding indices.

Parameters:
  • g (Tensor) – A vector containing group labels.

Returns:

  • groups (Tensor): A tensor containing the unique group labels present in g.

  • group_indices (list of Tensors): A list where each tensor contains the indices of elements in g that correspond to the respective group in groups.

  • unique_counts (Tensor): A tensor representing the count of each unique group in groups, with the same length as groups.

Return type:

tuple
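The shape of the returned tuple is easiest to see on a small input. The sketch below uses lists instead of tensors; the name `split_into_groups_sketch` is hypothetical, and returning the unique groups in sorted order is an assumption.

```python
def split_into_groups_sketch(g):
    """Return (groups, group_indices, unique_counts) for the labels in g.

    Only groups that actually occur in g are reported.
    Assumption: unique groups are listed in ascending order.
    """
    indices_by_group = {}
    for idx, group in enumerate(g):
        indices_by_group.setdefault(group, []).append(idx)
    groups = sorted(indices_by_group)
    group_indices = [indices_by_group[group] for group in groups]
    unique_counts = [len(indices) for indices in group_indices]
    return groups, group_indices, unique_counts
```

For g = [2, 0, 2, 5], the three outputs each have length 3, because labels 1, 3, and 4 never occur and are therefore omitted.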

milliontrees.common.utils.subsample_idxs(idxs, num=5000, take_rest=False, seed=None)[source]
milliontrees.common.utils.threshold_at_recall(y_pred, y_true, global_recall=60)[source]

Calculate the model threshold needed to achieve a desired global_recall level. Assumes that y_true is a vector of the true binary labels.

Parameters:
  • y_pred – Vector of predicted scores.

  • y_true – Vector of the true binary labels.

  • global_recall – Desired global recall level. Defaults to 60.
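One plausible way to obtain such a threshold is to take a percentile of the scores assigned to the true positives. The sketch below rests on two assumptions: global_recall is a percentage between 0 and 100, and the threshold is the (100 - global_recall)th percentile of the positive scores, computed with linear interpolation. The name `threshold_at_recall_sketch` is hypothetical.

```python
def threshold_at_recall_sketch(y_pred, y_true, global_recall=60):
    """Return a score threshold that keeps about global_recall percent of positives.

    Assumptions: global_recall is a percentage, labels are 0/1, and a
    prediction counts as positive when its score is >= the threshold.
    """
    positives = sorted(p for p, t in zip(y_pred, y_true) if t == 1)
    if not positives:
        raise ValueError("y_true contains no positive labels")

    # Fractional position of the (100 - global_recall)th percentile.
    rank = (100 - global_recall) / 100 * (len(positives) - 1)
    lower = int(rank)
    upper = min(lower + 1, len(positives) - 1)
    frac = rank - lower

    # Linear interpolation between the two neighboring positive scores.
    return positives[lower] + frac * (positives[upper] - positives[lower])
```

With five positives scored 0.1 through 0.5 and global_recall=60, the threshold falls between 0.2 and 0.3, so three of the five positives remain above it.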

Module contents