import torch
from datasets import load_dataset, load_dataset_builder
Using Hugging Face Image Datasets with fastai
Overview and Setup
The goal of this post is to find an efficient way of using fastai’s data pipeline tools when working with Hugging Face datasets. As it turns out, the DataBlock
and mid-level APIs have some drawbacks in this context, and we’ll see that a better option is to use the low-level API: fastai’s custom DataLoader
and Pipeline
classes. If you just want to see how to use the low-level API, feel free to skip to that section.
The list of Python packages and specific versions used to make this post are in the table below.
Package | Version |
---|---|
python | 3.9.16 |
PyTorch | 2.2.0+cu121 |
fastai | 2.7.14 |
timm | 0.9.12 |
datasets | 2.15.0 |
The latest version of the datasets library at the time of writing is 2.16.1, but some of the dataset metadata was missing when I tried versions 2.16.0 and 2.16.1, so this post uses 2.15.0.
Get the Metadata
We’ll be using the Fashion-MNIST dataset in our analysis. It’s a drop-in replacement for the original MNIST dataset. We’ll use the version of Fashion-MNIST from the Hugging Face Hub. One nice feature of Hugging Face’s datasets library is the DatasetBuilder
class that contains metadata for a given dataset. You can use load_dataset_builder
to download a DatasetBuilder
for any dataset on the Hugging Face Hub before downloading the dataset itself.
fashion_mnist_builder = load_dataset_builder("fashion_mnist")
The metadata we want is stored in the builder’s info
attribute.
fashion_mnist_info = fashion_mnist_builder.info
For example, there’s a written description of the dataset…
print(fashion_mnist_info.description)
Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of
60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image,
associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in
replacement for the original MNIST dataset for benchmarking machine learning algorithms.
It shares the same image size and structure of training and testing splits.
… a dictionary of pre-defined splits in the dataset, with relevant metadata…
fashion_mnist_info.splits
{'train': SplitInfo(name='train', num_bytes=31296607, num_examples=60000, shard_lengths=None, dataset_name='fashion_mnist'),
'test': SplitInfo(name='test', num_bytes=5233810, num_examples=10000, shard_lengths=None, dataset_name='fashion_mnist')}
… and a dictionary of all of the columns in the dataset, with their datatypes and relevant metadata.
fashion_mnist_info.features
{'image': Image(decode=True, id=None),
'label': ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)}
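All of this metadata is made up of plain Python objects, so individual fields can be pulled out directly. For example (an illustrative check, not part of the original walkthrough):
fashion_mnist_info.splits["train"].num_examples  # 60000
len(fashion_mnist_info.features["label"].names)  # 10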
The ClassLabel
object in info.features
has a convenient int2str
method to convert integer labels to the strings that they represent.
fashion_mnist_label_func = fashion_mnist_info.features["label"].int2str
num_classes = len(fashion_mnist_info.features["label"].names)
{i: fashion_mnist_label_func(i) for i in range(num_classes)}
{0: 'T - shirt / top',
1: 'Trouser',
2: 'Pullover',
3: 'Dress',
4: 'Coat',
5: 'Sandal',
6: 'Shirt',
7: 'Sneaker',
8: 'Bag',
9: 'Ankle boot'}
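As an aside, ClassLabel also provides the inverse mapping through its str2int method, in case we ever need to go from label names back to integers (not used in this post):
fashion_mnist_info.features["label"].str2int("Sneaker")  # 7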
Get the Data
We can also download and load the dataset using our DatasetBuilder
.
fashion_mnist_builder.download_and_prepare()
fashion_mnist = fashion_mnist_builder.as_dataset()
fashion_mnist
DatasetDict({
train: Dataset({
features: ['image', 'label'],
num_rows: 60000
})
test: Dataset({
features: ['image', 'label'],
num_rows: 10000
})
})
The splits, features, and number of training and test instances all match the metadata that we looked at in the previous section. We’ll discard the provided test set and divide the training set into a new training and validation set using the Hugging Face Dataset
class’s built-in train_test_split
method.
fashion_mnist = fashion_mnist["train"].train_test_split(test_size=0.2)
fashion_mnist
DatasetDict({
train: Dataset({
features: ['image', 'label'],
num_rows: 48000
})
test: Dataset({
features: ['image', 'label'],
num_rows: 12000
})
})
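Note that train_test_split shuffles the data before splitting, so the split above will vary from run to run. If reproducibility matters, the method accepts a seed argument, e.g. (hypothetical, not used in the rest of the post):
fashion_mnist = fashion_mnist_builder.as_dataset()["train"].train_test_split(test_size=0.2, seed=42)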
The DataBlock
API
One commonly used data processing API in the fastai library is the DataBlock
API. We’ll need to do some additional pre-processing steps when working with a Hugging Face Dataset
. The procedure we use is taken from an example in the documentation for blurr, a library that makes it easy to use Hugging Face transformers with fastai. The specific example I’m referencing is the second example in the linked “Prepare the datasets” section.
Step 1: Record the indices of the validation set.
num_train, num_valid = len(fashion_mnist["train"]), len(fashion_mnist["test"])
valid_idxs = list(range(num_train, num_train + num_valid))
Step 2: Concatenate the training and validation sets into one dataset.
We’ll use the concatenate_datasets
function from the datasets library.
from datasets import concatenate_datasets
concat_fashion_mnist = concatenate_datasets([fashion_mnist["train"],
                                             fashion_mnist["test"]])
Step 3: Pass an IndexSplitter
when creating our DataBlock
so that it knows which instances go in the training and validation sets.
from fastai.vision.all import (
DataBlock, ImageBlock, CategoryBlock, IndexSplitter, CropPad,
RandomCrop, Normalize, imagenet_stats, Rotate, RandomErasing
)
def get_x(item):
return item["image"]
def get_y(item):
return fashion_mnist_label_func(item["label"])
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=get_x,
    get_y=get_y,
    splitter=IndexSplitter(valid_idx=valid_idxs),
    item_tfms=[CropPad(32),
               RandomCrop(28)],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                Rotate(),
                RandomErasing()],
)
Now we can create our DataLoaders
.
fashion_mnist_dls = dblock.dataloaders(concat_fashion_mnist, bs=512)
As usual, we can use fastai’s conveniences like show_batch
.
fashion_mnist_dls.show_batch(figsize=(5, 5))
Then we can create a Learner
and fine-tune a model on our dataset like normal.
from fastai.vision.all import (
vision_learner, resnet18, accuracy,
minimum, steep, valley, slide
)
fashion_mnist_learn = vision_learner(fashion_mnist_dls, resnet18, metrics=accuracy)
lrs = fashion_mnist_learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
The graph seems to be steepest at the point marked steep
, so we’ll take that as our learning rate.
fashion_mnist_learn.fine_tune(3, lrs.steep)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.324347 | 0.824672 | 0.700250 | 00:15 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.608801 | 0.476613 | 0.848417 | 00:12 |
1 | 0.468581 | 0.308588 | 0.882750 | 00:11 |
2 | 0.400869 | 0.285220 | 0.893000 | 00:11 |
Great!
The Downside of the DataBlock
API with a Hugging Face Dataset
Everything looks fine on the surface, but there’s a hidden cost of using the DataBlock
API with a Hugging Face Dataset
: the time it takes to create the DataLoaders
. Let’s re-run our DataBlock
pipeline and time it.
import time
start_time = time.monotonic()

# Step 1
num_train, num_valid = len(fashion_mnist["train"]), len(fashion_mnist["test"])
valid_idxs = list(range(num_train, num_train + num_valid))

# Step 2
concat_fashion_mnist = concatenate_datasets([fashion_mnist["train"],
                                             fashion_mnist["test"]])

# Step 3
fashion_mnist_dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=get_x,
    get_y=get_y,
    splitter=IndexSplitter(valid_idx=valid_idxs),
    item_tfms=[CropPad(32),
               RandomCrop(28)],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                Rotate(),
                RandomErasing()],
)

# instantiate the dataloaders
fashion_mnist_dls = fashion_mnist_dblock.dataloaders(concat_fashion_mnist, bs=512)

end_time = time.monotonic()
print("Time elapsed:", round(end_time - start_time, 1), "seconds")
Time elapsed: 18.9 seconds
That’s too long for such a small dataset, and if we create a DataBlock
with a bigger dataset like Imagenette, the time required to instantiate the DataLoaders
will be even longer. Let’s first download the dataset…
imagenette = load_dataset("johnowhitaker/imagenette2-320")
Many thanks to Jonathan Whitaker for uploading this version of Imagenette to the Hugging Face Hub!
… then split the provided training set into a training and validation set…
imagenette = imagenette["train"].train_test_split(test_size=0.2)
imagenette
DatasetDict({
train: Dataset({
features: ['image', 'label'],
num_rows: 10715
})
test: Dataset({
features: ['image', 'label'],
num_rows: 2679
})
})
… then set up some new functions to use in our DataBlock
.
imagenette_label_dict = dict(
    n01440764='tench',
    n02102040='English springer',
    n02979186='cassette player',
    n03000684='chain saw',
    n03028079='church',
    n03394916='French horn',
    n03417042='garbage truck',
    n03425413='gas pump',
    n03445777='golf ball',
    n03888257='parachute'
)

imagenette_label_func = imagenette["train"].info.features["label"].int2str

def imagenette_get_y(item):
    label = imagenette_label_func(item["label"])
    return imagenette_label_dict[label]
The label dictionary was taken from fastai’s “Training Imagenette” tutorial.
Now we’ll run the exact same pre-processing steps, create a DataBlock
and DataLoaders
, time the whole process with time.monotonic
, and keep track of memory usage with the %%memit magic command.
%load_ext memory_profiler
from fastai.vision.all import Resize, aug_transforms
%%memit
start_time = time.monotonic()

# Step 1
num_train, num_valid = len(imagenette["train"]), len(imagenette["test"])
valid_idxs = list(range(num_train, num_train + num_valid))

# Step 2
concat_imagenette = concatenate_datasets([imagenette["train"],
                                          imagenette["test"]])

# Step 3
imagenette_dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=get_x,
    get_y=imagenette_get_y,
    splitter=IndexSplitter(valid_idx=valid_idxs),
    item_tfms=[Resize(460)],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                *aug_transforms(size=224, min_scale=0.75),
                RandomErasing()],
)

# instantiate the dataloaders
imagenette_dls = imagenette_dblock.dataloaders(concat_imagenette, bs=64)

end_time = time.monotonic()
print("Time elapsed:", round(end_time - start_time, 1), "seconds")
Time elapsed: 35.6 seconds
peak memory: 22003.21 MiB, increment: 15222.07 MiB
That’s a big memory increment! It looks like we need a different approach to using Hugging Face image datasets with fastai.
Image Paths
The Hugging Face datasets documentation indicates that, when loading a dataset that’s stored locally, we can get back paths instead of decoded images by using Dataset
’s cast_column
method. However, if you’re downloading a dataset from the Hugging Face Hub, as we’re doing in this post, you may get the images in a compressed format instead of as individual files, and this option won’t work.
import datasets

imagenette_with_paths = imagenette.cast_column("image", datasets.Image(decode=False))
list(imagenette_with_paths["train"][0]["image"].keys())
['bytes', 'path']
imagenette_with_paths["train"][0]["image"]["path"] is None
True
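For comparison, here's roughly what the path-based approach looks like when the images do exist as individual files on disk, for example a dataset loaded with the imagefolder builder (the data_dir below is hypothetical):
# hypothetical local copy of the dataset, loaded from image files on disk
local_ds = load_dataset("imagefolder", data_dir="path/to/imagenette2-320")
local_ds = local_ds.cast_column("image", datasets.Image(decode=False))
local_ds["train"][0]["image"]["path"]  # now an actual file path rather than None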
The Mid-Level API
The mid-level API suffers from the same problem: instantiating a fastai Datasets
or TfmdLists
object on Imagenette is just as slow and memory-hungry as the DataBlock
approach, so we need to turn to fastai's low-level API. A sketch of what such a mid-level setup might look like is shown below.
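Here's a rough sketch of that mid-level setup, reusing get_x, imagenette_get_y, concat_imagenette, and valid_idxs from the previous section. This is my reconstruction for illustration, not code from the original post; setting it up walks over the Hugging Face Dataset (for example, Categorize builds its vocab by scanning the labels), which is the kind of eager work that makes it slow.
from fastai.vision.all import Datasets, Categorize, PILImage, ToTensor, IntToFloatTensor

# mid-level sketch: same x/y pipelines as the DataBlock above
splits = IndexSplitter(valid_idx=valid_idxs)(concat_imagenette)
dsets = Datasets(concat_imagenette,
                 tfms=[[get_x, PILImage.create],           # x pipeline
                       [imagenette_get_y, Categorize()]],  # y pipeline
                 splits=splits)
dls = dsets.dataloaders(bs=64,
                        after_item=[Resize(460), ToTensor()],
                        after_batch=[IntToFloatTensor(),
                                     Normalize.from_stats(*imagenet_stats)])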
The Low-Level API
A faster way to use a Hugging Face image dataset with fastai is via the low-level API: Pipeline
s of fastai Transform
s and fastai’s custom DataLoader
class, which is compatible with PyTorch DataLoader
s but has some additional fastai-flavored features. We will also need to define some custom fastai-style transforms to handle our dataset format. We'll prototype the transforms on a sample item from Imagenette, then put a complete pipeline together for Fashion-MNIST first.
First up are some data conversion transforms. Hugging Face datasets return indexed items as dictionaries…
sample_input = imagenette["train"][0]
sample_input
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=320x484>,
'label': 7}
… while fastai transforms expect tuples, so we’ll need a transform to convert formats. Instead of writing a __call__
method, fastai transforms need an encodes
method.
from fastai.vision.all import Transform
class DictToTuple(Transform):
    order = -10       # do this before everything else
    split_idx = None  # apply to both training and validation sets

    def encodes(self, x):
        return x["image"], x["label"]
We’re relying on the fact that both of the datasets in this post use the column names image
and label
. A more flexible design would pass the relevant column names at instantiation or automatically extract them from the dataset if possible.
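That more flexible design might look something like the following sketch (hypothetical; the rest of the post sticks with the simpler DictToTuple above):
class DictToTupleNamed(Transform):
    order = -10       # do this before everything else
    split_idx = None  # apply to both training and validation sets

    def __init__(self, image_col="image", label_col="label"):
        # column names are configurable instead of hard-coded
        self.image_col, self.label_col = image_col, label_col

    def encodes(self, x):
        return x[self.image_col], x[self.label_col]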
Let’s make sure our transform does what we want it to do.
dict_to_tuple = DictToTuple()
tuple_input = dict_to_tuple(sample_input)
tuple_input
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=320x484>, 7)
Instead of calling the encodes
method directly, we just use the instantiated Transform
like a normal Python callable.
Next, fastai’s image transforms work best with its custom PILImage
class, so we’ll need another transform to convert our images to that format.
from PIL import Image
from fastai.vision.all import PILImage
class ConvertImage(Transform):
    order = -9        # do this after DictToTuple
    split_idx = None  # apply to both training and validation sets

    def encodes(self, x: Image.Image):
        return PILImage.create(x)
When a fastai Transform
is passed a tuple as input, fastcore’s type dispatch system applies it to individual elements of the tuple that match the type annotation of the encodes
method. For example, ConvertImage
will only be applied to PIL.Image.Image
objects, and our target integer labels will be left alone.
convert_image = ConvertImage()
image_converted = convert_image(tuple_input)
image_converted
(PILImage mode=RGB size=320x484, 7)
Finally, we’ll write a transform to convert the integer labels in our dataset to fastai TensorCategory
labels, and we’ll make sure they know how to decode and display themselves in figures later on by writing a decodes
method.
from fastai.vision.all import TensorCategory
class ConvertCategory(Transform):
    order = -8        # do this after ConvertImage
    split_idx = None  # apply to both the training and validation sets

    def __init__(self, label_func):
        self.label_func = label_func

    def encodes(self, x: int):
        return TensorCategory(x)

    def decodes(self, x: TensorCategory):
        return self.label_func(x)
As before, fastai’s type dispatching ensures that this transform will only be applied to the integer labels and will leave the images alone.
def label2cat(x):
    decoded_label = imagenette_label_func(x.item())
    return imagenette_label_dict[decoded_label]

convert_category = ConvertCategory(label2cat)
category_converted = convert_category(image_converted)
category_converted
(PILImage mode=RGB size=320x484, TensorCategory(7))
The decodes
method also uses type dispatching.
category_decoded = convert_category.decode(category_converted)
category_decoded
(PILImage mode=RGB size=320x484, 'gas pump')
Instead of calling the decodes
method of the Transform
directly, we call the decode
method (without the s
).
I could not get fastai's RandomErasing
to work with the low-level API, so we'll wrap the implementation from timm in a fastai Transform
instead.
# make sure not to override fastai's RandomErasing class
import timm.data.random_erasing as random_erasing
# fastai's ToTensor transform converts PILImages to fastai TensorImages
from fastai.vision.all import TensorImage
class RandomErase(Transform):
    order = 100    # do this after Normalize
    split_idx = 0  # apply only to the training set

    def __init__(self, p=0.5, mode='pixel', max_count=1):
        device = ('cuda' if torch.cuda.is_available()
                  else 'mps' if torch.backends.mps.is_available()
                  else 'cpu')
        self.tfm = random_erasing.RandomErasing(
            probability=p, mode=mode, device=device, max_count=max_count
        )

    def encodes(self, x: TensorImage):
        return self.tfm(x)
We’ll always want some specific transforms in our data pre-processing, but we might want to experiment with including or excluding others. To make that easier, we’ll design our data pipelines to be easily extensible. The transforms in a fastai Pipeline
are automatically ordered by their order
class attributes, which makes extensibility much easier to implement.
from fastai.vision.all import ToTensor, IntToFloatTensor, Pipeline
def item_tfms_pipeline(split_idx, label_func, extra_tfms=None):
    tfms = [DictToTuple(),
            ConvertImage(),
            ConvertCategory(label_func),
            ToTensor()]
    if extra_tfms is not None:
        tfms.extend(extra_tfms)
    return Pipeline(tfms, split_idx=split_idx)

def batch_tfms_pipeline(split_idx, extra_tfms=None):
    tfms = [IntToFloatTensor()]
    if extra_tfms is not None:
        tfms.extend(extra_tfms)
    return Pipeline(tfms, split_idx=split_idx)
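As a quick sanity check (illustrative, not from the original post), we can hand a Pipeline its transforms in a scrambled order and confirm that it sorts them by their order attributes:
# transforms passed out of order on purpose
demo_pipe = Pipeline([ToTensor(), ConvertImage(), DictToTuple()], split_idx=0)
demo_pipe  # Pipeline: DictToTuple -> ConvertImage -> ToTensor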
Now we can write some functions that generate a custom DataLoaders
.
from fastai.vision.all import DataLoader, DataLoaders
def get_dl(ds, bs, shuffle, device, label_func, item_tfms=None, batch_tfms=None):
    return DataLoader(
        ds, bs=bs, shuffle=shuffle,
        after_item=item_tfms_pipeline(
            split_idx=int(not shuffle),
            label_func=label_func,
            extra_tfms=item_tfms
        ),
        after_batch=batch_tfms_pipeline(
            split_idx=int(not shuffle),
            extra_tfms=batch_tfms
        ),
        device=device,
        num_workers=8,
    )

def get_dls(ds_dict, bs, item_tfms, batch_tfms, label_dict=None):
    def label_func(x):
        result = ds_dict["train"].info.features["label"].int2str(x)
        if label_dict is not None:
            result = [label_dict[o] for o in result]
        return result

    device = ('cuda' if torch.cuda.is_available()
              else 'mps' if torch.backends.mps.is_available()
              else 'cpu')

    dls = DataLoaders(
        *[get_dl(ds_dict[k], bs=bs, shuffle=(k == "train"), device=device,
                 label_func=label_func, item_tfms=item_tfms, batch_tfms=batch_tfms)
          for k in ds_dict],
    )
    # need to set this or vision_learner complains
    dls.c = len(ds_dict["train"].info.features["label"].names)

    return dls
We've designed things so that the syntax is similar to what we used with the DataBlock API.
fashion_mnist_low_level_dls = get_dls(
    ds_dict=fashion_mnist,
    bs=512,
    item_tfms=[CropPad(32),
               RandomCrop(28)],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                Rotate(),
                RandomErase()]
)
Now let’s train a model using our custom DataLoaders
and compare with the DataBlock
API.
from fastai.vision.all import CrossEntropyLossFlat
fashion_mnist_low_level_learn = vision_learner(
    dls=fashion_mnist_low_level_dls,
    arch=resnet18,
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy
)
lrs = fashion_mnist_low_level_learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
As usual, we’ll use the learning rate given by the point marked steep
.
fashion_mnist_low_level_learn.fine_tune(3, lrs.steep)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.360952 | 0.741448 | 0.745417 | 00:20 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.596736 | 0.389657 | 0.862917 | 00:19 |
1 | 0.468772 | 0.316965 | 0.880583 | 00:16 |
2 | 0.383638 | 0.258381 | 0.902500 | 00:17 |
The training time per epoch is a little longer, but we still save time overall when we take into account the time needed to instantiate the DataLoaders
. Let’s time the DataLoaders
instantiation so we can compare the low-level API with the DataBlock
API:
start_time = time.monotonic()

fashion_mnist_low_level_dls = get_dls(
    ds_dict=fashion_mnist,
    bs=512,
    item_tfms=[CropPad(32),
               RandomCrop(28)],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                Rotate(),
                RandomErase()]
)

end_time = time.monotonic()

print("Time elapsed:", round(end_time - start_time, 3), "seconds")
Time elapsed: 0.003 seconds
Much faster!
The fastai DataLoader
class doesn’t have a built-in show_batch
method, but we can easily patch one in with fastcore’s @patch
decorator. fastai already defines a type-dispatched show_batch function, but it isn't attached to our plain DataLoader objects, so we'll add our own; @patch reads the type annotation on self and attaches show_batch as a method of the DataLoaders class.
from operator import itemgetter
from fastcore.all import patch, mapt
from fastai.vision.all import show_images
@patch
def show_batch(self: DataLoaders, split='train', max_n=8, figsize=(6, 3)):
    batch = getattr(self, split).one_batch()
    slicer = itemgetter(slice(max_n))
    batch = mapt(slicer, batch)
    batch = getattr(self, split).after_batch.decode(batch)
    batch = getattr(self, split).after_item.decode(batch)
    images, labels = batch  # unpacking for clarity
    show_images(images, nrows=2, figsize=figsize, titles=labels)

fashion_mnist_low_level_dls.show_batch()
A fastai Pipeline
has its own decode
method that calls the decode
methods of the constituent transforms in the appropriate reversed order. Each dataloader in a fastai DataLoaders
has its own after_item
and after_batch
pipelines.
Now let’s see how long it takes to instantiate a DataLoaders
and train a model on Imagenette with the low-level API. We’ll need an extra custom Transform
to convert the grayscale images in Imagenette to RGB format.
class ConvertToRGB(Transform):
    order = 6         # after ToTensor
    split_idx = None  # apply to both the training and validation sets

    def encodes(self, x: TensorImage):
        if x.shape[0] == 3:
            return x
        return x.repeat(3, 1, 1)
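A quick illustrative check of the transform on a fake one-channel image (again, not from the original post):
gray = TensorImage(torch.rand(1, 28, 28))  # pretend grayscale image
ConvertToRGB()(gray).shape                 # torch.Size([3, 28, 28])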
We’ll also measure the memory usage when instantiating the DataLoaders
.
%%memit
start_time = time.monotonic()

imagenette_low_level_dls = get_dls(
    ds_dict=imagenette,
    bs=64,
    item_tfms=[Resize(460), ConvertToRGB()],
    batch_tfms=[Normalize.from_stats(*imagenet_stats),
                *aug_transforms(size=224, min_scale=0.75),
                RandomErase()],
    label_dict=imagenette_label_dict
)

end_time = time.monotonic()
print("Elapsed time:", round(end_time - start_time, 3), "seconds")
Elapsed time: 0.006 seconds
peak memory: 22117.25 MiB, increment: 0.00 MiB
Much faster, and we’re not using any additional memory!
We can also look at a batch of Imagenette images using our patched-in show_batch
method for DataLoaders
.
imagenette_low_level_dls.show_batch(figsize=(10, 5))
Now let’s fine-tune a model.
imagenette_low_level_learn = vision_learner(
    dls=imagenette_low_level_dls,
    arch=resnet18,
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy
)
lrs = imagenette_low_level_learn.lr_find(suggest_funcs=(minimum, steep, valley, slide))
As usual, we’ll take the point marked steep
as our learning rate.
imagenette_low_level_learn.fine_tune(3, lrs.steep)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.334781 | 0.070108 | 0.979097 | 00:23 |
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.197842 | 0.070056 | 0.980590 | 00:24 |
1 | 0.143560 | 0.058491 | 0.983949 | 00:23 |
2 | 0.101656 | 0.048196 | 0.986562 | 00:24 |
Just like with Fashion-MNIST, the training time per epoch for Imagenette is longer with the low-level API than with the DataBlock
API. However, we save so much time and memory usage when creating the DataLoaders
with the low-level API that it comes out well ahead overall. Our data pipeline also remains flexible and extensible for working with Hugging Face datasets, so there's little downside.