Datasets
All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.
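For example:
import torch
import torchvision

num_workers = 4  # e.g. taken from a command-line option such as args.nThreads
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=num_workers)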
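Because DataLoader only relies on __getitem__ and __len__, the protocol is easy to see in isolation. The sketch below uses a hypothetical toy dataset (SquaresDataset is not part of torchvision) and a simplified sequential batcher, iterate_batches, that stands in for DataLoader(shuffle=False) without collating samples into tensors. It is plain Python so it runs without PyTorch installed.

```python
class SquaresDataset:
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        return self.n


def iterate_batches(dataset, batch_size):
    """Yield lists of samples in index order (no shuffling, no collation)."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # keep the last, smaller batch, like DataLoader's drop_last=False
        yield batch


batches = list(iterate_batches(SquaresDataset(5), batch_size=2))
print(batches)  # [[(0, 0), (1, 1)], [(2, 4), (3, 9)], [(4, 16)]]
```

With PyTorch installed, the same object can be handed to torch.utils.data.DataLoader(SquaresDataset(5), batch_size=2), which additionally collates each batch into tensors. In real code you would normally subclass torch.utils.data.Dataset instead of using a bare class.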
The following datasets are available: MNIST, COCO, LSUN, ImageFolder, Imagenet-12, CIFAR, STL10, SVHN and PhotoTour.
All the datasets have a very similar API. They all have two common arguments, transform and target_transform, to transform the input and the target respectively.
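Inside a dataset, the two arguments are usually applied in __getitem__: transform to the input, target_transform to the label. A minimal sketch, assuming a hypothetical in-memory dataset (TransformedPairs) whose inputs are plain lists and whose labels are strings; plain Python callables stand in for torchvision transforms.

```python
class TransformedPairs:
    """Hypothetical dataset showing where transform and target_transform are applied."""

    def __init__(self, pairs, transform=None, target_transform=None):
        self.pairs = pairs  # list of (input, target) tuples
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample, target = self.pairs[index]
        if self.transform is not None:
            sample = self.transform(sample)  # applied to the input only
        if self.target_transform is not None:
            target = self.target_transform(target)  # applied to the label only
        return sample, target

    def __len__(self):
        return len(self.pairs)


ds = TransformedPairs(
    [([1, 2, 3], "cat"), ([4, 5, 6], "dog")],
    transform=lambda xs: [x / 10 for x in xs],  # stand-in for an image transform
    target_transform={"cat": 0, "dog": 1}.get,  # maps class names to indices
)
print(ds[1])  # ([0.4, 0.5, 0.6], 1)
```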
MNIST
class torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
MNIST Dataset.
Parameters:
- root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
- train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
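A sketch of the file selection and "download only once" behaviour described above, assuming only the processed/training.pt and processed/test.pt layout from the parameter list. mnist_file, ensure_downloaded and fake_fetch are hypothetical helpers; fake_fetch writes empty placeholder files and is not torchvision's actual download code.

```python
import os
import tempfile


def mnist_file(root, train=True):
    """Path of the processed file read for the requested split."""
    return os.path.join(root, "processed", "training.pt" if train else "test.pt")


def ensure_downloaded(root, fetch):
    """Call fetch(root) only if a processed file is missing; return True if it ran."""
    if all(os.path.exists(mnist_file(root, t)) for t in (True, False)):
        return False  # already downloaded: skip
    fetch(root)
    return True


def fake_fetch(root):
    """Hypothetical stand-in for the real download: creates empty placeholder files."""
    os.makedirs(os.path.join(root, "processed"), exist_ok=True)
    for train in (True, False):
        open(mnist_file(root, train), "w").close()


with tempfile.TemporaryDirectory() as root:
    first = ensure_downloaded(root, fake_fetch)   # files missing, so fetch runs
    second = ensure_downloaded(root, fake_fetch)  # files present, so it is skipped
print(first, second)  # True False
```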
COCO
Note
These require the COCO API (pycocotools) to be installed.
Detection
class torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None)
MS Coco Detection Dataset.
Parameters:
- root (string) – Root directory where images are downloaded to.
- annFile (string) – Path to json annotation file.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.ToTensor.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
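In torchvision's implementation, the target returned for an image is the list of annotation dicts that refer to it (check the source for your version). A minimal sketch of that grouping with only the json module, as a stand-in for the index the COCO API builds; index_coco and the tiny inline annotation document are hypothetical, while the images, annotations, id, image_id and file_name keys follow the COCO annotation format. The sketch parses a JSON string rather than reading annFile from disk.

```python
import json
import os
from collections import defaultdict


def index_coco(ann_json_text, root):
    """Map each image id to (image path, list of its annotation dicts)."""
    data = json.loads(ann_json_text)
    anns_by_image = defaultdict(list)
    for ann in data["annotations"]:
        anns_by_image[ann["image_id"]].append(ann)
    return {
        img["id"]: (os.path.join(root, img["file_name"]), anns_by_image[img["id"]])
        for img in data["images"]
    }


example = json.dumps({
    "images": [{"id": 1, "file_name": "a.jpg"}, {"id": 2, "file_name": "b.jpg"}],
    "annotations": [
        {"id": 10, "image_id": 1, "category_id": 3, "bbox": [0, 0, 5, 5]},
        {"id": 11, "image_id": 1, "category_id": 7, "bbox": [2, 2, 4, 4]},
    ],
})
index = index_coco(example, "images")
path, target = index[1]
print(path, [a["category_id"] for a in target])  # images/a.jpg [3, 7] (POSIX separator)
```

Images without annotations, like id 2 here, get an empty list as their target.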
LSUN
class torchvision.datasets.LSUN(db_path, classes='train', transform=None, target_transform=None)
LSUN dataset.
Parameters:
- db_path (string) – Root directory for the database files.
- classes (string or list) – One of {'train', 'val', 'test'} or a list of categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
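The classes argument can be read as a shorthand that expands to one database per category. A sketch under stated assumptions: LSUN_CATEGORIES lists the ten LSUN scene categories, the test split is assumed to be a single uncategorised database named 'test', and resolve_lsun_classes is a hypothetical helper rather than torchvision's code.

```python
LSUN_CATEGORIES = [  # assumption: the ten LSUN scene categories
    "bedroom", "bridge", "church_outdoor", "classroom", "conference_room",
    "dining_room", "kitchen", "living_room", "restaurant", "tower",
]


def resolve_lsun_classes(classes="train"):
    """Expand 'train'/'val'/'test' into database names; pass explicit lists through."""
    if isinstance(classes, str):
        if classes == "test":
            return ["test"]  # assumption: test data is one database, not per category
        if classes in ("train", "val"):
            return [f"{c}_{classes}" for c in LSUN_CATEGORIES]
        raise ValueError(f"unknown split: {classes!r}")
    return list(classes)


print(resolve_lsun_classes("val")[:2])  # ['bedroom_val', 'bridge_val']
```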
ImageFolder
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters:
- root (string) – Root directory path.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- loader (callable, optional) – A function to load an image given its path.
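The directory layout above maps onto labels roughly as follows: each subdirectory of root is a class, classes are sorted by name and numbered from 0, and every image file under a class directory becomes a (path, class_index) pair; loader is then called on the path when a sample is accessed. A minimal sketch of that indexing, assuming a hypothetical make_samples helper and a reduced extension list; the demo creates empty placeholder files rather than real images.

```python
import os
import tempfile


def make_samples(root, extensions=(".png", ".jpg", ".jpeg")):
    """Hypothetical ImageFolder-style indexing: returns (class_to_idx, samples)."""
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}  # sorted names -> 0, 1, ...
    samples = []
    for name in classes:
        for dirpath, _, filenames in sorted(os.walk(os.path.join(root, name))):
            for fname in sorted(filenames):
                if fname.lower().endswith(extensions):  # skip non-image files
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return class_to_idx, samples


with tempfile.TemporaryDirectory() as root:
    for rel in ["dog/xxx.png", "dog/xxy.png", "cat/123.png", "cat/nsdf3.png"]:
        path = os.path.join(root, *rel.split("/"))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "wb").close()  # empty placeholder, not a real image
    class_to_idx, samples = make_samples(root)
    labels = [(os.path.relpath(p, root), y) for p, y in samples]

print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

torchvision's ImageFolder exposes the resulting mapping as its class_to_idx attribute, which is handy for turning predicted indices back into folder names.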
Imagenet-12
This should simply be implemented with an ImageFolder dataset. The data is preprocessed as described here.
CIFAR
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
CIFAR10 Dataset.
Parameters:
- root (string) – Root directory of dataset where directory cifar-10-batches-py exists.
- train (bool, optional) – If True, creates dataset from the training set, otherwise creates it from the test set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
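In the Python version of CIFAR-10, the cifar-10-batches-py directory holds five training batches (data_batch_1 to data_batch_5) and one test_batch, each a pickled dict. A sketch of how train selects files and how the labels can be read back; cifar10_batch_files and load_labels are hypothetical helpers, and the demo writes small fake pickles containing only a 'labels' key instead of the real 10000-image batches.

```python
import os
import pickle
import tempfile

TRAIN_FILES = [f"data_batch_{i}" for i in range(1, 6)]
TEST_FILES = ["test_batch"]


def cifar10_batch_files(root, train=True):
    """Paths of the pickled batches that make up the requested split."""
    names = TRAIN_FILES if train else TEST_FILES
    return [os.path.join(root, "cifar-10-batches-py", n) for n in names]


def load_labels(root, train=True):
    """Concatenate the 'labels' lists of every batch in the split."""
    labels = []
    for path in cifar10_batch_files(root, train):
        with open(path, "rb") as f:
            # latin1 lets Python 3 read the Python 2 pickles of the real dataset
            batch = pickle.load(f, encoding="latin1")
        labels.extend(batch["labels"])
    return labels


with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "cifar-10-batches-py"))
    # Fake batches with two labels each, written for demonstration only
    all_files = cifar10_batch_files(root, True) + cifar10_batch_files(root, False)
    for i, path in enumerate(all_files):
        with open(path, "wb") as f:
            pickle.dump({"labels": [i % 10, (i + 1) % 10]}, f)
    train_labels = load_labels(root, train=True)
    test_labels = load_labels(root, train=False)

print(len(train_labels), test_labels)  # 10 [5, 6]
```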
STL10
class torchvision.datasets.STL10(root, split='train', transform=None, target_transform=None, download=False)
STL10 Dataset.
Parameters:
- root (string) – Root directory of dataset where directory stl10_binary exists.
- split (string) – One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. The corresponding part of the dataset is selected.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
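The split string names one or two parts of the archive, and images in the 'unlabeled' part have no class. A sketch of resolving a split and combining labels, assuming -1 as the placeholder label for unlabeled images; stl10_parts, combine_labels and that placeholder value are hypothetical choices for illustration, so check what your torchvision version actually returns for unlabeled samples.

```python
VALID_SPLITS = ("train", "test", "unlabeled", "train+unlabeled")
UNLABELED = -1  # assumption: placeholder label for images without a class


def stl10_parts(split):
    """Return (part, has_labels) pairs for the parts a split reads, in order."""
    if split not in VALID_SPLITS:
        raise ValueError(f"split must be one of {VALID_SPLITS}, got {split!r}")
    return [(part, part != "unlabeled") for part in split.split("+")]


def combine_labels(split, labels_by_part, sizes):
    """Concatenate labels across parts, filling unlabeled parts with UNLABELED."""
    labels = []
    for part, has_labels in stl10_parts(split):
        if has_labels:
            labels.extend(labels_by_part[part])
        else:
            labels.extend([UNLABELED] * sizes[part])
    return labels


print(stl10_parts("train+unlabeled"))  # [('train', True), ('unlabeled', False)]
print(combine_labels("train+unlabeled", {"train": [3, 1]}, {"unlabeled": 2}))  # [3, 1, -1, -1]
```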
SVHN
class torchvision.datasets.SVHN(root, split='train', transform=None, target_transform=None, download=False)
SVHN Dataset. Note: The SVHN dataset assigns the label 10 to the digit 0. However, in this Dataset, we assign the label 0 to the digit 0 to be compatible with PyTorch loss functions, which expect the class labels to be in the range [0, C-1].
Parameters:
- root (string) – Root directory of dataset where directory SVHN exists.
- split (string) – One of {'train', 'test', 'extra'}. The corresponding part of the dataset is selected; 'extra' is the extra training set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
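The relabeling described in the note amounts to mapping 10 to 0 and leaving the other labels alone. A minimal sketch over a plain list; remap_svhn_labels and the sample labels are hypothetical.

```python
def remap_svhn_labels(labels):
    """Map SVHN's label 10 (digit 0) to 0 so all labels lie in [0, 9]."""
    return [0 if y == 10 else y for y in labels]


raw = [10, 1, 9, 10, 5]  # hypothetical raw labels using SVHN's convention
fixed = remap_svhn_labels(raw)
print(fixed)  # [0, 1, 9, 0, 5]
```

With a NumPy array, the same remapping is a single masked assignment, labels[labels == 10] = 0.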
PhotoTour
class torchvision.datasets.PhotoTour(root, name, train=True, transform=None, download=False)
Learning Local Image Descriptors Data Dataset.
Parameters:
- root (string) – Root directory where images are.
- name (string) – Name of the dataset to load.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.