Datasets

All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

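import torch
import torchvision
import torchvision.transforms as transforms

# Crop and convert each image to an equally sized tensor so that samples can be batched
transform = transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor()])
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/', transform=transform)
data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True,
                                          num_workers=4)  # e.g. args.nThreads from argparse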
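Because DataLoader relies only on __getitem__ and __len__, any map-style dataset is batched the same way. The sketch below uses a hypothetical SquaresDataset (not part of torchvision) to show the protocol; num_workers=0 keeps loading in the main process, which is convenient for quick checks.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i] and its target is i * i."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Return an (input, target) pair, like the torchvision datasets below
        return torch.tensor([float(index)]), index * index

    def __len__(self):
        return self.n


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False, num_workers=0)
batches = list(loader)
# The default collate function stacks the inputs and turns integer targets into a tensor
first_inputs, first_targets = batches[0]
```

With 10 samples and batch_size=4, the loader yields batches of 4, 4 and 2 samples.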

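The following datasets are available:

  • MNIST
  • COCO (Captions and Detection)
  • LSUN
  • ImageFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • SVHN
  • PhotoTour

All the datasets have nearly the same API. They all have two common arguments, transform and target_transform, which transform the input and the target, respectively.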
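Conceptually, each dataset loads a raw sample in __getitem__ and then applies the two callables if they are set. The following is a simplified pure-Python sketch of that pattern, not torchvision's code; ListDataset is a hypothetical name.

```python
class ListDataset:
    """Hypothetical dataset over in-memory (input, target) pairs."""

    def __init__(self, pairs, transform=None, target_transform=None):
        self.pairs = pairs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample, target = self.pairs[index]
        # transform only touches the input, target_transform only the target
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.pairs)


ds = ListDataset([("a", 1), ("b", 2)], transform=str.upper, target_transform=lambda t: t - 1)
print(ds[1])  # ('B', 1)
```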

MNIST

class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)

MNIST Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
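MNIST targets are class indices from 0 to 9. As an illustration of target_transform, the hypothetical one_hot helper below turns such an index into a one-hot list. It is plain Python, so it runs without downloading anything; the commented line shows how it could be passed to MNIST (the root path is a placeholder).

```python
def one_hot(label, num_classes=10):
    """Hypothetical target_transform: map a class index to a one-hot list."""
    index = int(label)  # int() also accepts 0-dim integer tensors
    if not 0 <= index < num_classes:
        raise ValueError(f"label {index} outside [0, {num_classes - 1}]")
    return [1 if i == index else 0 for i in range(num_classes)]


# Usage with MNIST (downloads the data on first use):
# mnist = torchvision.datasets.MNIST('path/to/mnist', train=True, download=True,
#                                    target_transform=one_hot)
print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

Most PyTorch classification losses expect the integer index directly, so a transform like this is mainly useful for losses or metrics that work on one-hot targets.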

COCO

Note

These datasets require the COCO API (pycocotools) to be installed.

Captions

class torchvision.datasets.CocoCaptions(root, annFile, transform=None, target_transform=None)

MS Coco Captions Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Example

import torchvision.datasets as dset
import torchvision.transforms as transforms

cap = dset.CocoCaptions(root='dir where images are',
                        annFile='json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3]  # load 4th sample (indices start at 0)

print("Image Size: ", img.size())
print(target)

Output:

Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target), where target is a list of captions for the image.
Return type: tuple
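Each image has a variable number of captions, so the default DataLoader collate function cannot batch the raw targets cleanly. One option, sketched here as an assumption rather than a torchvision feature, is a target_transform that keeps a single caption per image; pick_caption is a hypothetical helper.

```python
import random

example_captions = ["A plane emitting smoke stream flying over a mountain.",
                    "A plane leaves a contrail above the snowy mountain top."]


def pick_caption(captions, rng=random):
    """Hypothetical target_transform: keep one randomly chosen caption per image."""
    # By default this uses the module-level generator, which recent PyTorch
    # versions reseed in each DataLoader worker process
    return rng.choice(captions)


# Usage:
# cap = dset.CocoCaptions(root='dir where images are', annFile='json annotation file',
#                         transform=transforms.ToTensor(), target_transform=pick_caption)
chosen = pick_caption(example_captions, rng=random.Random(0))
```

Choosing a random caption each time a sample is loaded gives the model a different caption for the same image across epochs; returning captions[0] instead would make the targets fixed.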

Detection

class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)

MS Coco Detection Dataset.

Parameters:
  • root (string) – Root directory where images are downloaded to.
  • annFile (string) – Path to json annotation file.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target), where target is the object returned by coco.loadAnns.
Return type: tuple
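The target here is a list of COCO annotation dicts, in which "bbox" is stored as [x, y, width, height]. A minimal sketch of a target_transform that converts them to corner-format boxes plus category ids is shown below; coco_to_boxes and the output dict layout are hypothetical choices, and the example annotations are made up.

```python
def coco_to_boxes(annotations):
    """Hypothetical target_transform for CocoDetection.

    Converts the list of annotation dicts returned by coco.loadAnns into
    corner-format boxes [x1, y1, x2, y2] and a parallel list of category ids.
    """
    boxes, labels = [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]  # COCO stores [x, y, width, height]
        boxes.append([x, y, x + w, y + h])
        labels.append(ann["category_id"])
    return {"boxes": boxes, "labels": labels}


example = [{"bbox": [10.0, 20.0, 30.0, 40.0], "category_id": 18},
           {"bbox": [0.0, 0.0, 5.0, 5.0], "category_id": 1}]
print(coco_to_boxes(example))
# {'boxes': [[10.0, 20.0, 40.0, 60.0], [0.0, 0.0, 5.0, 5.0]], 'labels': [18, 1]}
```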

LSUN

class torchvision.datasets.LSUN(db_path, classes='train', transform=None, target_transform=None)

LSUN dataset.

Parameters:
  • db_path (string) – Root directory for the database files.
  • classes (string or list) – One of {‘train’, ‘val’, ‘test’} or a list of categories to load, e.g. [‘bedroom_train’, ‘church_train’].
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
__getitem__(index)
Parameters: index (int) – Index
Returns: Tuple (image, target), where target is the index of the target category.
Return type: tuple
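When classes is a list, each entry names one per-category database of the form <category>_<split>. The hypothetical helper below builds such a list; it assumes that per-category databases exist only for the train and val splits, so the test set is requested with classes='test' directly. LSUN itself reads lmdb databases, so the lmdb package must be installed to use it.

```python
def lsun_classes(categories, split="train"):
    """Hypothetical helper: build LSUN class names such as 'bedroom_train'."""
    if split not in ("train", "val"):
        # Assumption: only train/val are stored per category; use classes='test' for the test set
        raise ValueError("split must be 'train' or 'val'")
    return [f"{category}_{split}" for category in categories]


# Usage (requires the lmdb package and the LSUN databases under db_path):
# lsun = torchvision.datasets.LSUN('path/to/lsun', classes=lsun_classes(['bedroom', 'kitchen']))
print(lsun_classes(["bedroom", "kitchen"], split="val"))  # ['bedroom_val', 'kitchen_val']
```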

ImageFolder

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)

A generic data loader where the images are arranged in this way:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

Parameters:
  • root (string) – Root directory path.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • loader (callable, optional) – A function to load an image given its path.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the class_index of the target class.
Return type: tuple
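The class_index comes from the directory layout: subdirectories of root are sorted by name and numbered from 0, and ImageFolder exposes the result through its classes and class_to_idx attributes. The following is a simplified pure-Python sketch of that indexing, not torchvision's code; index_image_folder is hypothetical and IMG_EXTENSIONS is a reduced, assumed extension list.

```python
import os
import tempfile

IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumption: reduced extension list


def index_image_folder(root):
    """Sketch: map each class directory to an index and list (path, index) samples."""
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                if fname.lower().endswith(IMG_EXTENSIONS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


with tempfile.TemporaryDirectory() as root:
    # Recreate a small version of the layout shown above with empty placeholder files
    for path in ["dog/xxx.png", "dog/xxy.png", "cat/123.png"]:
        full_path = os.path.join(root, path)
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        open(full_path, "wb").close()
    classes, class_to_idx, samples = index_image_folder(root)
    print(classes, class_to_idx)  # ['cat', 'dog'] {'cat': 0, 'dog': 1}
```

Because the order is alphabetical, adding a new class directory can shift the indices of existing classes, so it is worth saving class_to_idx alongside a trained model.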

Imagenet-12

This should simply be implemented with an ImageFolder dataset. The data is preprocessed as described here.

Here is an example.

CIFAR

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)

CIFAR10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists.
  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple
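CIFAR-10 targets are indices into a fixed list of ten class names. For inspecting samples, a hypothetical target_transform can map an index to its name, as sketched below; for training, keep the integer targets, since loss functions expect indices rather than strings.

```python
# Label order used by CIFAR-10 (as listed in the dataset's batches.meta file)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def cifar10_name(target):
    """Hypothetical target_transform: map a CIFAR-10 class index to its name."""
    return CIFAR10_CLASSES[int(target)]


# Usage: torchvision.datasets.CIFAR10('path/to/cifar', train=False, download=True,
#                                     target_transform=cifar10_name)
print(cifar10_name(3))  # cat
```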

STL10

class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)

STL10 Dataset.

Parameters:
  • root (string) – Root directory of dataset where directory stl10_binary exists.
  • split (string) – One of {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}. Selects which split of the dataset to load.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple
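With split='train+unlabeled', labeled and unlabeled images are mixed in one dataset. A minimal sketch for separating them, assuming (as in torchvision's implementation we are aware of; check your version) that unlabeled images carry the placeholder target -1. split_labeled is a hypothetical helper that takes the list of targets.

```python
UNLABELED = -1  # assumption: placeholder target for STL10's unlabeled images


def split_labeled(targets):
    """Return (labeled_indices, unlabeled_indices) for a sequence of targets."""
    labeled, unlabeled = [], []
    for i, t in enumerate(targets):
        (unlabeled if int(t) == UNLABELED else labeled).append(i)
    return labeled, unlabeled


# The two index lists could feed torch.utils.data.Subset to build separate
# supervised and unsupervised loaders over the same 'train+unlabeled' dataset.
labeled, unlabeled = split_labeled([3, -1, 0, -1, 9])
print(labeled, unlabeled)  # [0, 2, 4] [1, 3]
```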

SVHN

class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)

SVHN Dataset. Note: The SVHN dataset assigns the label 10 to the digit 0. However, this Dataset assigns the label 0 to the digit 0, to be compatible with PyTorch loss functions, which expect class labels in the range [0, C-1].

Parameters:
  • root (string) – Root directory of dataset where directory SVHN exists.
  • split (string) – One of {‘train’, ‘test’, ‘extra’}. Selects which split of the dataset to load; ‘extra’ is the extra training set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (image, target), where target is the index of the target class.
Return type: tuple
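The label remapping described in the note amounts to taking the original label modulo 10, which sends 10 to 0 and leaves 1 to 9 unchanged, so all labels fall in [0, C-1] with C = 10. The dataset already does this internally; the hypothetical helper below only illustrates the mapping, for example when reading the original .mat files yourself.

```python
def svhn_to_torch_label(raw_label):
    """Map an original SVHN label (1-10, where 10 means digit 0) to the digit itself."""
    return int(raw_label) % 10


print([svhn_to_torch_label(l) for l in [1, 5, 9, 10]])  # [1, 5, 9, 0]
```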

PhotoTour

class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)

Learning Local Image Descriptors Dataset.

Parameters:
  • root (string) – Root directory where images are.
  • name (string) – Name of the dataset to load.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
__getitem__(index)
Parameters: index (int) – Index
Returns: (data1, data2, matches)
Return type: tuple
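Each (data1, data2, matches) triple is a patch pair plus a label saying whether the two patches show the same 3D point (assumed here to be 1 for a match and 0 otherwise). Descriptors on this benchmark are commonly compared by the false positive rate at 95% recall; the sketch below computes it from per-pair descriptor distances. fpr_at_recall is a hypothetical helper, and the example distances are made up.

```python
import math


def fpr_at_recall(distances, matches, recall=0.95):
    """False positive rate at the distance threshold that admits `recall` of true matches.

    distances: descriptor distance for each (data1, data2) pair (smaller = more similar)
    matches: 1 if the pair shows the same 3D point, else 0 (assumed label convention)
    """
    n_pos = sum(matches)
    n_neg = len(matches) - n_pos
    needed = math.ceil(recall * n_pos)  # true matches that must fall under the threshold
    if n_neg == 0 or needed == 0:
        return 0.0
    tp = fp = 0
    # Sort by distance; at equal distance, non-matches come first (pessimistic)
    for _, is_match in sorted(zip(distances, matches)):
        if is_match:
            tp += 1
        else:
            fp += 1
        if tp >= needed:
            break
    return fp / n_neg


print(fpr_at_recall([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [1, 0, 1, 1, 0, 0]))  # 0.3333333333333333
```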