FastAI: Picking the Best Model

In the Beginning

In this notebook we'll go over part of fastai course lesson 3, "Which image models are best?". We'll use the benchmarking data from timm (PyTorch Image Models) to compare how different computer vision models performed, using time-per-image and accuracy as our metrics.

Imports and Setup

# from python
from functools import partial
from pathlib import Path

# from pypi
from tabulate import tabulate

import altair
import pandas

# monkey
from graeae.visualization.altair_helpers import output_path, save_chart
TABLE = partial(tabulate, tablefmt="orgtbl", headers=["Column", "Value"])

PLOT_WIDTH, PLOT_HEIGHT = 900, 600
SLUG = "fastai-picking-the-best-model"
OUTPUT_PATH = output_path(SLUG)
save_it = partial(save_chart, output_path=OUTPUT_PATH)

The Validation Data

We'll be using data that's part of the git repository for timm. Once you clone the repository, the first file we want is results/results-imagenet.csv, which holds the results of validating the models on the ImageNet validation set.

RESULTS = Path("~/projects/third-party/"
               "pytorch-image-models/results").expanduser()
DATA = RESULTS/"results-imagenet.csv"
validation = pandas.read_csv(DATA)

print(TABLE(validation.iloc[0].to_frame()))
| Column        | Value                  |
|---------------+------------------------|
| model         | beit_large_patch16_512 |
| top1          | 88.602                 |
| top1_err      | 11.398                 |
| top5          | 98.656                 |
| top5_err      | 1.344                  |
| param_count   | 305.67                 |
| img_size      | 512                    |
| crop_pct      | 1.0                    |
| interpolation | bicubic                |

This table shows the first row of the results-imagenet CSV. Each row represents a computer vision model along with some information about how it performed during validation. The documentation says that top1 and top5 are "top-1/top-5 differences from clean validation." Which means… what? Looking at the validate.py script, it appears that top1 and top5 are measures of accuracy. In the timm/utils/metrics.py module the accuracy function has a docstring that says: "Computes the accuracy over the k top predictions for the specified values of k." In validate.py, top1 and top5 are AverageMeter objects that keep a running average of those accuracies.

This seems straightforward enough, and it explains why, in that first row, top1 is smaller than top5 and has a larger error: top-5 counts a prediction as correct whenever the true label appears anywhere among the five highest-scoring classes, so top-5 accuracy is always at least as high as top-1 accuracy (and top-5 error at least as low).
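
To make that concrete, here's a rough sketch of top-k accuracy (this is not timm's implementation, just an illustration): a prediction counts as correct at k if the true label is anywhere among the k highest-scoring classes.

def top_k_accuracy(scores: list, labels: list, k: int) -> float:
    """Percent of rows whose true label is among the k highest scores"""
    hits = sum(
        label in sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        for row, label in zip(scores, labels))
    return 100 * hits / len(labels)

# three made-up predictions over three classes
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [2, 0, 1]
print(top_k_accuracy(scores, labels, k=1))  # ~33.3: only the second row's top guess is right
print(top_k_accuracy(scores, labels, k=2))  # 100.0: every true label is somewhere in the top two

The top1 and top5 columns are running averages of exactly this kind of score over the validation set.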

Judging by the name, the model in our row is an instance of BEIT ("BEIT: BERT Pre-Training of Image Transformers", https://arxiv.org/abs/2106.08254), found in timm's beit.py module.

print(validation.shape)
(668, 9)

The model column is the string you use when creating a model and also refers to a function in one of the pytorch-image-models/timm/models modules. If you want to see how the model in our example row is defined, look in the timm/models/beit.py module for a function named "beit_large_patch16_512". You should find something like this.

@register_model
def beit_large_patch16_512(pretrained=False, **kwargs):
    model_kwargs = dict(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
    model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
    return model

So we can now see that besides being a BEIT model the name tells us that it used an image size of 512 and a patch size of 16. Further up the file is this configuration:

'beit_large_patch16_512': _cfg(
    url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
    input_size=(3, 512, 512), crop_pct=1.0,
),

This tells you, among other things, where the pretrained weights came from.
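
As an aside, because these functions are registered with timm, the string in the model column is also what you pass to timm.create_model if you want to instantiate one of these models yourself (a minimal sketch, assuming the timm package is installed):

import timm

# pulls the pretrained weights from the URL given in the _cfg entry above
model = timm.create_model("beit_large_patch16_512", pretrained=True)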

The Benchmark Data

We're going to merge our "validation" data with two "benchmark" files (also in the "results" folder), doing some cryptic filtering and data wrangling along the way. It's not obvious what everything does, so let's use it first and maybe figure out most of it later. The main things to note are that we add a family column by taking the first token from the model name (e.g. the model beit_large_patch16_512 gets the family beit), we add a secs column by inverting the samples-per-second column, and we filter the models down to a subset that's useful to look at.

BENCHMARK_FILE = ("benchmark-{infer_or_train}"
                  "-amp-nhwc-pt111-cu113-rtx3090.csv")
SAMPLE_RATE = "{infer_or_train}_samples_per_sec"
FAMILY_REGEX = r'^([a-z]+?(?:v2)?)(?:\d|_|$)'
FAMILY_FILTER = r'^re[sg]netd?|beit|convnext|levit|efficient|vit|vgg'

def get_data(infer_or_train: str,
             validation: pandas.DataFrame=validation) -> pandas.DataFrame:
    """Load a benchmark dataframe

    Args:
     infer_or_train: part of filename with label (infer or train)
     validation: DataFrame created from validation results file

    Returns:
     benchmark data merged with validation
    """

    frame = pandas.read_csv(
        RESULTS/BENCHMARK_FILE.format(
            infer_or_train=infer_or_train)).merge(
        validation, on='model')
    frame['secs'] = 1. / frame[SAMPLE_RATE.format(infer_or_train=infer_or_train)]
    frame['family'] = frame.model.str.extract(FAMILY_REGEX)
    # drop models whose names end in "gn" (the group-norm variants)
    frame = frame[~frame.model.str.endswith('gn')]

    # give models pre-trained on ImageNet-22k their own "_in22" family suffix
    IN_FILTERER = frame.model.str.contains('in22'), "family"
    frame.loc[IN_FILTERER] = frame.loc[IN_FILTERER] + '_in22'

    # split the resnet "d" variants into their own "resnetd" family
    RESNET_FILTERER = frame.model.str.contains('resnet.*d'), 'family'
    frame.loc[RESNET_FILTERER] = frame.loc[RESNET_FILTERER] + 'd'
    return frame[frame.family.str.contains(FAMILY_FILTER)]
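
To see what the FAMILY_REGEX extraction actually produces, here's a quick sanity check on a few model names (not part of the lesson):

# the first capture group becomes the family:
# beit, levit, efficientnetv2, and resnet
# (the resnet.*d step above then turns "resnet" for resnet50d into "resnetd")
samples = pandas.Series(
    ["beit_large_patch16_512", "levit_128s", "efficientnetv2_rw_m", "resnet50d"])
print(samples.str.extract(FAMILY_REGEX))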

Build The Base Chart

The build_chart function is going to help us build the basic chart to compare the merged validation and benchmark values for the models.

SELECTION = altair.selection_multi(fields=["family"], bind="legend")
COLUMNS = ["secs", "top1", "family", "model"]

def build_chart(frame: pandas.DataFrame, infer_or_train: str,
                add_selection: bool=True) -> altair.Chart:
    """Build the basic chart for our benchmarks

    Note:
     the ``add_selection`` method can only be called once on a chart, so if
     you plan to add more layers don't add the selection here; add it to the
     final layered chart instead

    Args:
     frame: benchmark frame to plot
     infer_or_train: which image size column (infer | train)
     add_selection: whether to add the selection at the end
    """
    # altair includes all the data even if it's not used in the plot
    # reducing the dataframe to just the data you need
    # makes the file smaller
    SIZE = f"{infer_or_train}_img_size"
    frame = frame[COLUMNS + [SIZE]]
    chart = altair.Chart(frame).mark_circle().encode(
        x=altair.X("secs", scale=altair.Scale(type="log"),
                   axis=altair.Axis(title="Seconds Per Image (log)")),
        y=altair.Y("top1",
                   scale=altair.Scale(zero=False),
                   axis=altair.Axis(title="Imagenet Accuracy")),
        size=altair.Size(SIZE,
                         scale=altair.Scale(
                             type="pow", exponent=2)),
        color="family",
        tooltip=[altair.Tooltip("family", title="Architecture Family"),
                 altair.Tooltip("model", title="Model"),
                 altair.Tooltip(SIZE, format=",", title="Image Size"),
                 altair.Tooltip("top1", title="Accuracy"),
                 altair.Tooltip("secs", title="Time (sec)", format=".2e")
                 ]
        )
    if add_selection:
        chart = chart.encode(opacity=altair.condition(
                SELECTION,
                altair.value(1),
                altair.value(0.1))
        ).add_selection(SELECTION)
    return chart

Plot All the Architectures

Our first chart for the benchmarking data will plot all the models left in the data-frame after our filtering and merging to show us how they compare for accuracy and average time to process a sample.

def plot_it(frame: pandas.DataFrame,
            title: str,
            filename: str,
            infer_or_train: str,
            width: int=PLOT_WIDTH,
            height: int=PLOT_HEIGHT) -> None:
    """Make an altair plot of the frame

    Args:
     frame: benchmark frame to plot
     title: title to give the plot
     filename: name of file to save the chart to
     infer_or_train: which image size column (infer or train)
     width: width of plot in pixels
     height: height of plot in pixels
    """
    chart = build_chart(frame, infer_or_train).properties(
        title=title,
        width=width,
        height=height,
    )

    save_it(chart, filename)
    return

Plot Some of the Architectures

To make it easier to understand, the author of the fastai lesson chose a subset of the families to plot.

  • beit
  • convnext
  • efficientnetv2
  • levit
  • regnetx
  • resnetd
  • vgg

Note: The fastai notebook points out that, because the models were trained on datasets of different sizes (some were pre-trained on the much larger ImageNet-22k, for instance), it isn't a simple case of picking the "best" performing model given a speed vs accuracy trade-off. The pytorch-image-models repository has information to help research what went into the training.

FAMILIES = 'levit|resnetd?|regnetx|vgg|convnext.*|efficientnetv2|beit'

def subset_regression(frame: pandas.DataFrame,
                      title: str,
                      filename: str,
                      infer_or_train: str,
                      width: int=PLOT_WIDTH,
                      height: int=PLOT_HEIGHT) -> None:
    """Plot subset of model-families

    Args:
     frame: frame with benchmark data
     title: title to give the plot
     filename: name to save the file
     infer_or_train: which image size column
     width: width of plot in pixels
     height: height of plot in pixels
    """
    subset = frame[frame.family.str.fullmatch(FAMILIES)]

    base = build_chart(subset, infer_or_train, add_selection=False)

    line = base.transform_regression(
        "secs", "top1",
        groupby=["family"],
        method="log",
        ).mark_line().encode(
            opacity=altair.condition(
                SELECTION,
                altair.value(1),
                altair.value(0.1)
            ))

    chart = base.encode(
        opacity=altair.condition(
        SELECTION,
        altair.value(1),
        altair.value(0.1)
    ))

    chart = altair.layer(chart, line).properties(
        title=title,
        width=width,
        height=height,
    ).add_selection(SELECTION)

    save_it(chart, filename)
    return
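
Since str.fullmatch requires the whole family name to match the pattern, here's a quick check (not from the lesson) of which family labels FAMILIES keeps:

candidates = pandas.Series(
    ["resnet", "resnetd", "regnetx", "regnety", "vit", "beit", "convnext"])

# keeps resnet, resnetd, regnetx, beit, and convnext;
# regnety and vit don't fully match any alternative, so they're left out of the subset
print(candidates[candidates.str.fullmatch(FAMILIES)].tolist())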

Inference

The first benchmarking data we're going to add is the inference data. The repository doesn't spell out exactly how it was produced, but the filename (benchmark-infer-amp-nhwc-pt111-cu113-rtx3090.csv) suggests these are inference timings measured with PyTorch 1.11 and CUDA 11.3 on an RTX 3090, using automatic mixed precision (amp) and channels-last (nhwc) layout. For our plots it only contributes the average time per sample, which perhaps isn't as interesting as the accuracy anyway.

inference = get_data('infer')
print(TABLE(inference.iloc[0].to_frame()))
| Column                | Value                 |
|-----------------------+-----------------------|
| model                 | levit_128s            |
| infer_samples_per_sec | 21485.8               |
| infer_step_time       | 47.648                |
| infer_batch_size      | 1024                  |
| infer_img_size        | 224                   |
| param_count_x         | 7.78                  |
| top1                  | 76.514                |
| top1_err              | 23.486                |
| top5                  | 92.87                 |
| top5_err              | 7.13                  |
| param_count_y         | 7.78                  |
| img_size              | 224                   |
| crop_pct              | 0.9                   |
| interpolation         | bicubic               |
| secs                  | 4.654236751715086e-05 |
| family                | levit                 |

Let's look at a row of what was added to our original validation data.

added = inference[list(set(inference.columns) - set(validation.columns))].iloc[0]
print(TABLE(added.to_frame()))
| Column                | Value                 |
|-----------------------+-----------------------|
| secs                  | 4.654236751715086e-05 |
| family                | levit                 |
| param_count_y         | 7.78                  |
| infer_batch_size      | 1024                  |
| param_count_x         | 7.78                  |
| infer_samples_per_sec | 21485.8               |
| infer_step_time       | 47.648                |
| infer_img_size        | 224                   |

If you look back at get_data you'll see that we added the secs column, defined as \(\frac{1}{\textit{samples per second}}\), so it's the (average) number of seconds it takes to process one sample.
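
As a quick check against the first row shown above:

# secs is just the reciprocal of the reported samples-per-second rate
print(1 / 21485.8)  # roughly 4.654e-05, which matches the secs column (up to rounding)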

Let's see how evenly distributed the families are.

counts = inference.family.value_counts().to_frame().reset_index().rename(
    columns = {"index": "Family", "family": "Count"})

chart = altair.Chart(counts).mark_bar().encode(
    x="Count", y=altair.Y("Family", sort="-x"), tooltip=["Count"],
).properties(
    width=PLOT_WIDTH,
    height=PLOT_HEIGHT,
    title="Inference Family Counts"
)

save_it(chart, "inference-family-counts")

Figure Missing

There doesn't seem to be an even representation of model families. Let's look at the accuracy vs the speed for the models.

plot_it(inference, title="Inference", 
        filename="inference-benchmark",
        infer_or_train="infer")

Figure Missing

While we still don't have an explanation of exactly how the benchmark was run, in the broadest terms this is a plot of the time it takes a model to process an image (in seconds, on a logarithmic scale) against its accuracy on the ImageNet validation set.

  • The color matches the family in the legend.
  • The size of each circle reflects the image size the model was benchmarked with.
  • Clicking on a family in the legend will highlight it and suppress the other families.
  • Hovering over a circle gives the exact information for that point.

I believe the accuracy is each model's best validation performance, so even though a family may have multiple points in the plot, each model only has one point, representing its accuracy and its time per image.
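
One quick way to check that assumption (not part of the lesson):

# True would mean the merged frame has exactly one row, and so one point, per model
print(inference.model.is_unique)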

A Subset

To make it easier to see what's going on, the author(s) of the fastai lesson pared the dataset down to a subset of families and then added regression lines to compare them.

subset_regression(inference,
                  title="Inference Subset",
                  filename="inference-subset-benchmark",
                  infer_or_train="infer")

Figure Missing

Training

training = get_data("train")
plot_it(training, title="Training", 
        filename="training-benchmark",
        infer_or_train="train")

Figure Missing

subset_regression(training,
                  title="Training Subset",
                  filename="training-subset-benchmark",
                  infer_or_train="train")

Figure Missing

Parameters Vs Time

The fastai notebook plots the models' parameter counts against their time per image, noting that parameter count is sometimes used as a proxy for speed and memory use (presumably to make comparisons machine-independent), but also that it isn't always a good proxy. Once more they give us a tool and then tell us it isn't necessarily the one to use.

plotter = inference[["param_count_x", "secs", "infer_img_size", "family", "model", "top1"]]
chart = altair.Chart(plotter).mark_circle().encode(
    x=altair.X("param_count_x", scale=altair.Scale(type="log"),
               axis=altair.Axis(title="Parameters (log)")),
    y=altair.Y("secs", scale=altair.Scale(type="log", zero=False),
               axis=altair.Axis(title="Seconds Per Image (log)")),
    color="infer_img_size",
    tooltip=[altair.Tooltip("family", title="Architecture Family"),
             altair.Tooltip("model", title="Model"),
             altair.Tooltip("infer_img_size", format=",", title="Image Size"),
             altair.Tooltip("top1", title="Accuracy"),
             altair.Tooltip("secs", title="Time (sec)", format=".2e")
             ],
    opacity=altair.condition(
        SELECTION,
        altair.value(1),
        altair.value(0.1))
).add_selection(SELECTION).properties(
    title="Parameters Vs Time",
    width=PLOT_WIDTH,
    height=PLOT_HEIGHT-100)

save_it(chart, "inference-parameters-vs-time")

Figure Missing

In this case parameters and speed do look correlated, in that the time per image generally increases with the parameter count, but the relationship is confounded by the fact that the models with more parameters also tend to be handling bigger images.
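
One way to probe that (a sketch, not from the lesson) is to compare the rank correlation between parameter count and seconds-per-image overall with the same correlation computed within each image size:

# overall rank correlation between parameter count and time per image
print(inference.param_count_x.corr(inference.secs, method="spearman"))

# the same correlation within each image size, which removes the image-size confound
print(inference.groupby("infer_img_size")
      .apply(lambda group: group.param_count_x.corr(group.secs, method="spearman")))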

Accuracy Vs Size

The fastai notebook also looks at accuracy versus the image size the models work with; in this version of the plot the circle size shows the time per image instead.

plotter = inference[["param_count_x", "img_size",
                     "family", "model", "secs", "top1"]]
chart = altair.Chart(plotter).mark_circle().encode(
    x=altair.X("img_size", scale=altair.Scale(zero=False),
               axis=altair.Axis(title="Image Size")),
    y=altair.Y("top1",
               scale=altair.Scale(zero=False),
               axis=altair.Axis(title="Accuracy")),
    size=altair.Size("secs", scale=altair.Scale(type="log")),
    color="family",
    tooltip=[altair.Tooltip("family", title="Architecture Family"),
             altair.Tooltip("model", title="Model"),
             altair.Tooltip("img_size", format=",", title="Image Size"),
             altair.Tooltip("top1", title="Accuracy"),
             altair.Tooltip("secs", title="Time (sec)", format=".2e")
             ],
    opacity=altair.condition(
        SELECTION,
        altair.value(1),
        altair.value(0.1))
).add_selection(SELECTION).properties(
    title="Accuracy Vs Image Size",
    width=PLOT_WIDTH,
    height=PLOT_HEIGHT-100)

save_it(chart, "inference-accuracy-vs-size")

Figure Missing

Sources