I’ve been working with PyTorch for a while now, and I keep coming back to the same sticking point whenever I start a new project: custom datasets. The built-in datasets are great for learning, but the moment you want to work with your own images, audio files, or any data that doesn’t fit the standard mold, you’re faced with building something from scratch. That’s exactly what I want to walk through in this article.
In this tutorial we’ll build custom datasets in PyTorch from the ground up. I’ll show you how to load unlabeled images from a folder, labeled images using torchvision’s ImageFolder, and even how to adapt the same patterns to audio data. By the end you’ll have a reusable pattern you can drop into any project.
TLDR
- Custom datasets in PyTorch inherit from torch.utils.data.Dataset and must implement __len__ and __getitem__
- For unlabeled images from a folder, build a custom class that lists files with os.listdir and loads them with PIL
- For labeled images, torchvision.datasets.ImageFolder handles the directory structure automatically
- The same dataset pattern works for audio, CSV data, or any file type by swapping the loading logic inside __getitem__
How PyTorch Dataset and DataLoader Work Together
Before we write any code, I want to make sure we are on the same page about how PyTorch actually handles data. The two primitives you need to know are Dataset and DataLoader.
A Dataset stores your samples and their labels. It knows nothing about training loops or batch sizes. It just knows how to hand you one item at a time via dataset[i]. A DataLoader wraps that and gives you batching, shuffling, multiprocessing, and all the machinery that makes training fast.
The key thing that tripped me up early on: a typical Dataset loads lazily. Nothing in the base class forces this, but the usual pattern is for __init__ to store only file paths, so no images are read into memory when you create the object. Instead, each file is read from disk only when the DataLoader requests that specific index. This is what makes PyTorch memory efficient for large image collections, but it also means your __getitem__ implementation needs to be fast, because it runs once for every sample in every batch.
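To see the on-demand reads in action, here is a toy sketch. The class name CountingDataset is made up for illustration, and its __getitem__ just records which indices it was asked for. With the default num_workers=0, creating the dataset triggers no reads, and pulling one batch of four triggers exactly four __getitem__ calls.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class CountingDataset(Dataset):
    """Hypothetical toy dataset that records every index it is asked for."""

    def __init__(self, n):
        self.n = n
        self.requested = []  # only bookkeeping here, nothing is loaded

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # In a real dataset, this is where the file read would happen
        self.requested.append(idx)
        return torch.tensor([float(idx)])


ds = CountingDataset(100)
print(ds.requested)  # [] : creating the dataset read nothing

loader = DataLoader(ds, batch_size=4, shuffle=False)  # num_workers=0 by default
first_batch = next(iter(loader))
print(first_batch.shape)  # torch.Size([4, 1])
print(ds.requested)       # [0, 1, 2, 3] : one __getitem__ call per sample
```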
When you create a custom dataset you inherit from torch.utils.data.Dataset and implement two methods:
- __len__ – returns the total number of samples so len(dataset) works
- __getitem__ – takes an index and returns the sample and label at that position
That’s the entire contract. Everything else is up to you. Now let’s see what this looks like in practice.
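Here is a minimal sketch of that contract. It uses a made-up SquaresDataset that keeps everything in memory so it runs anywhere: sample i is the number i and its label is i squared. The same two methods are enough for both direct indexing and DataLoader batching.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Makes len(dataset) work and tells the sampler how many indices exist
        return self.n

    def __getitem__(self, idx):
        if idx < 0 or idx >= self.n:
            raise IndexError(idx)
        x = torch.tensor([float(idx)])
        y = torch.tensor(idx * idx)
        return x, y  # (sample, label)


ds = SquaresDataset(10)
print(len(ds))  # 10
print(ds[3])    # (tensor([3.]), tensor(9))

loader = DataLoader(ds, batch_size=4)
xb, yb = next(iter(loader))   # the (sample, label) tuples are collated per field
print(xb.shape, yb.tolist())  # torch.Size([4, 1]) [0, 1, 4, 9]
print(len(loader))            # 3 batches: 4 + 4 + 2
```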
Loading Unlabeled Images from a Folder
Let’s start with the simplest case: you have a folder full of images with no labels, and you just want to load them all for training a GAN or an autoencoder. No subfolders, no class names, just a flat directory of images.
Here’s a custom dataset I have used for exactly this:
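import os
from natsort import natsorted  # third-party package: pip install natsort
from PIL import Image
from torch.utils.data import Dataset

class LoadFromFolder(Dataset):
    def __init__(self, main_dir, transform=None):
        self.main_dir = main_dir
        self.transform = transform
        # Assumes main_dir contains only image files (no subfolders or stray files)
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # One file is opened per call and converted to 3-channel RGB
        image = Image.open(img_loc).convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Without a transform this is a PIL image, which DataLoader's default collate cannot batch
        return image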
You can then pair this with a DataLoader like this:
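from torch.utils.data import DataLoader

# `transform` is the Compose pipeline shown next; it must produce equal-sized tensors
dataset = LoadFromFolder(main_dir="./data", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
print(next(iter(dataloader)).shape)  # e.g. torch.Size([32, 3, 224, 224]) with the transform below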
The transform parameter is typically a composition of torchvision.transforms like resizing, cropping, and converting to a tensor. Here’s a common pattern:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
One thing I almost forgot when I first wrote this: natsorted from the natsort package gives you natural sorting, so image10.jpg comes after image9.jpg instead of landing before image2.jpg, as it would under plain string sorting. It’s a small detail that saves you debugging time later.
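If you would rather avoid the extra dependency, a common stdlib-only alternative is a regex-based sort key. This is a simplified sketch, not how natsort works internally, and it ignores cases natsort handles, such as signed or decimal numbers. The helper name natural_key is made up for illustration.

```python
import re


def natural_key(name):
    """Hypothetical sort key: compare digit runs as integers, everything else as lowercase text."""
    # re.split with a capturing group keeps the digit runs, so text and numbers alternate
    return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r"(\d+)", name)]


files = ["image10.jpg", "image9.jpg", "image1.jpg", "image2.jpg"]
print(sorted(files))                   # ['image1.jpg', 'image10.jpg', 'image2.jpg', 'image9.jpg']
print(sorted(files, key=natural_key))  # ['image1.jpg', 'image2.jpg', 'image9.jpg', 'image10.jpg']
```

Inside LoadFromFolder, you would swap natsorted(all_imgs) for sorted(all_imgs, key=natural_key).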
Loading Labeled Images with ImageFolder
The unlabeled case is straightforward because there is no labeling logic. But what if you have a classification problem where each image belongs to a class? That’s where torchvision.datasets.ImageFolder saves you a lot of work.
ImageFolder assumes your directory structure follows a specific convention: each subfolder is a class, and the subfolder name is the class label. Like this:
root/
    cat/
        cat001.jpg
        cat002.jpg
    dog/
        dog001.jpg
        dog002.jpg
Given that structure, you can load the dataset in two lines:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
dataset = ImageFolder(root="./data", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
ImageFolder automatically builds a class-to-index mapping. You can access it via dataset.class_to_idx if you need to check which index corresponds to which label. The dataset returns a tuple of (image, label_index) when you index it.
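If you need more control over how images are labeled or loaded, ImageFolder accepts a custom loader (the function that opens each file) and an is_valid_file callable for filtering files. You can also subclass it and override find_classes to change how class folders map to labels. But for most classification tasks the base class does exactly what you need.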
Loading Custom Data Types Beyond Images
The dataset pattern is not limited to images at all. I have used the same structure to load audio spectrograms, CSV time-series data, and even proprietary binary formats. The only part that changes is the loading and transformation logic inside __getitem__.
Here’s a skeleton for a spectrogram dataset that I have found useful for audio classification tasks:
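import os
from torch.utils.data import Dataset

class SpectrogramDataset(Dataset):
    def __init__(self, file_label_ds, transform, audio_path=""):
        # file_label_ds: any indexable collection of (filename, label) pairs
        self.ds = file_label_ds
        # transform: a callable that takes a file path and returns a spectrogram tensor.
        # It is required, because __getitem__ relies on it to load each file.
        self.transform = transform
        self.audio_path = audio_path

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        file, label = self.ds[index]
        # os.path.join adds the path separator that plain string concatenation would miss
        spectrogram = self.transform(os.path.join(self.audio_path, file))
        return spectrogram, label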
The file_label_ds parameter in this example is any dataset that gives you a filename and a label. It could be a simple list of tuples, a CSV file read with pandas, or another PyTorch dataset. The key is that your custom dataset wraps it and handles the file-to-tensor conversion.
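SpectrogramDataset leaves the path-to-spectrogram step to whatever callable you pass as transform. Here is a sketch under stated assumptions: waveforms are stored as 1-D float32 .npy files at 16 kHz, and the STFT uses a 400-sample window with a 160-sample hop. All of these choices are made up for illustration. Under those assumptions, a hypothetical npy_to_spectrogram built on torch.stft could fill that role, and a plain list of tuples works as file_label_ds:

```python
import os
import tempfile

import numpy as np
import torch

N_FFT = 400       # assumed STFT window size (25 ms at 16 kHz)
HOP_LENGTH = 160  # assumed hop length (10 ms at 16 kHz)


def npy_to_spectrogram(path):
    """Hypothetical path -> spectrogram callable.

    Assumes each file is a 1-D float32 waveform saved with np.save.
    """
    waveform = torch.from_numpy(np.load(path))
    stft = torch.stft(
        waveform,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        window=torch.hann_window(N_FFT),
        return_complex=True,
    )
    # Magnitude spectrogram with a leading channel dimension: (1, freq_bins, frames)
    return stft.abs().unsqueeze(0)


# Write one synthetic 1-second, 16 kHz tone so the example runs end to end
audio_dir = tempfile.mkdtemp()
t = np.arange(16000, dtype=np.float32) / 16000
np.save(os.path.join(audio_dir, "tone_440.npy"), np.sin(2 * np.pi * 440 * t).astype(np.float32))

# A plain list of (filename, label) tuples is enough to act as file_label_ds
file_label_ds = [("tone_440.npy", 0)]
file, label = file_label_ds[0]
spectrogram = npy_to_spectrogram(os.path.join(audio_dir, file))
print(spectrogram.shape, label)  # torch.Size([1, 201, 101]) 0
```

With transform=npy_to_spectrogram, SpectrogramDataset’s __getitem__ performs this same lookup-then-convert step for each index. For real recordings, torchaudio’s loading and spectrogram transforms would be the more usual choice than .npy files.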
The same idea applies to tabular data. Replace the image loading with a pandas DataFrame.iloc[index] read, apply any preprocessing transforms, and return the feature tensor and label.
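Here is a minimal sketch of that tabular variant. TabularDataset, the column names, and the three-row DataFrame are all hypothetical. Every column except the label column is treated as a float feature.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class TabularDataset(Dataset):
    """Hypothetical example: each DataFrame row becomes a (features, label) pair."""

    def __init__(self, df, label_col, transform=None):
        self.df = df.reset_index(drop=True)
        self.label_col = label_col
        self.feature_cols = [c for c in df.columns if c != label_col]
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Positional row lookup, the tabular counterpart of opening one image file
        row = self.df.iloc[index]
        features = torch.tensor(row[self.feature_cols].to_numpy(dtype="float32"))
        label = torch.tensor(int(row[self.label_col]))
        if self.transform:
            features = self.transform(features)
        return features, label


df = pd.DataFrame({"x1": [0.1, 0.2, 0.3], "x2": [1.0, 2.0, 3.0], "y": [0, 1, 0]})
ds = TabularDataset(df, label_col="y")
features, labels = next(iter(DataLoader(ds, batch_size=2)))
print(features.shape, labels.tolist())  # torch.Size([2, 2]) [0, 1]
```

Calling iloc once per sample is simple but slow for large tables. A common speed-up is to convert the feature columns to a single tensor in __init__ and index that tensor in __getitem__.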
FAQ
Q: What is the difference between Dataset and DataLoader in PyTorch?
Dataset is an abstract class that provides the __len__ and __getitem__ interface for accessing individual samples. DataLoader wraps a Dataset and provides batching, shuffling, multiprocessing workers, and automatic collation of samples into batches.
Q: How do you create a custom dataset in PyTorch?
Create a class that inherits from torch.utils.data.Dataset and implement __len__ (returning the sample count) and __getitem__ (returning the sample at a given index). The __getitem__ method typically loads data from disk, applies transforms, and returns the sample and label.
Q: Does PyTorch load all images into memory at once?
No, not unless you write your dataset to do so. A typical custom Dataset, like the ones in this article and torchvision’s ImageFolder, stores only file paths and reads data from disk when DataLoader requests a specific index via __getitem__. This keeps memory usage low even for large datasets, provided your __getitem__ implementation does not accumulate data across calls.
Q: How does ImageFolder determine class labels?
ImageFolder uses the subfolder names in the root directory as class names. The names are sorted alphabetically and assigned indices in that order, and every image inside a subfolder gets that subfolder’s index as its label. The class-to-index mapping is stored in dataset.class_to_idx.
Q: Can a custom dataset return anything other than images?
Yes. The Dataset interface is completely agnostic to data type. Any data that can be represented as a tensor or a collection of tensors can be returned from __getitem__. Audio spectrograms, tabular data, text embeddings, and binary sensor streams have all been packaged as PyTorch datasets using this pattern.
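Here is a hedged illustration of returning something other than an image: a hypothetical sensor-window dataset that returns a dict per sample. DataLoader’s default collate function batches dicts key by key. Tensors are stacked, Python ints become a tensor, and strings are gathered into a list.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SensorWindows(Dataset):
    """Hypothetical dataset returning a dict per sample instead of a tuple."""

    def __init__(self, n_windows=6, window_len=5):
        # Synthetic signal: n_windows rows of window_len consecutive values
        self.data = torch.arange(n_windows * window_len, dtype=torch.float32).view(n_windows, window_len)
        self.ids = [f"window_{i}" for i in range(n_windows)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {"signal": self.data[idx], "id": self.ids[idx], "label": idx % 2}


batch = next(iter(DataLoader(SensorWindows(), batch_size=3)))
print(batch["signal"].shape)  # torch.Size([3, 5])
print(batch["id"])            # ['window_0', 'window_1', 'window_2']
print(batch["label"])         # tensor([0, 1, 0])
```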
Wrapping Up
If you take nothing else from this article, remember this: a custom dataset in PyTorch comes down to defining two methods. __len__ tells PyTorch how many samples you have, and __getitem__ handles the loading. Everything else, from batching to multiprocessing, is handled by DataLoader, which means you can focus on the data and not the plumbing.
The examples here cover the most common cases you are likely to hit: unlabeled images from a folder, labeled images via the directory structure, and the general pattern for adapting any file format into a dataset. Once you see how simple the interface is, you will find yourself reusing this pattern constantly.

