# Source code for torchvision.datasets.phototour

import errno
import os

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

from .utils import check_integrity, download_url

class PhotoTour(data.Dataset):
    """`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.

    Args:
        root (string): Root directory where images are.
        name (string): Name of the dataset to load (a key of ``urls``:
            ``'notredame'``, ``'yosemite'`` or ``'liberty'``).
        train (bool, optional): If True, items are single patches; otherwise
            items are ``(patch1, patch2, is_match)`` triples taken from the
            ground-truth matches file.
        transform (callable, optional): A function/transform that takes in a
            single patch as stored in the cached data file (a 64x64 uint8
            tensor when the cache was produced by ``download()``) and returns
            a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the cached ``<root>/<name>.pt`` file does not exist
            after the optional download step.
    """
    # name -> [download url, archive filename, md5 of the archive]
    urls = {
        'notredame': [
            'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip',
            'notredame.zip',
            '509eda8535847b8c0a90bbb210c83484'
        ],
        'yosemite': [
            'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip',
            'yosemite.zip',
            '533b2e8eb7ede31be40abc317b2fd4f0'
        ],
        'liberty': [
            'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
            'liberty.zip',
            'fdd9152f138ea5ef2091746689176414'
        ],
    }
    mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437}
    std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019}
    # Number of patches kept per dataset; also the train-mode length.
    lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092}

    image_ext = 'bmp'
    info_file = 'info.txt'
    matches_files = 'm50_100000_100000_0.txt'

    def __init__(self, root, name, train=True, transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, '{}.zip'.format(name))
        self.data_file = os.path.join(self.root, '{}.pt'.format(name))

        self.train = train
        self.transform = transform

        # Instance attributes shadow the class-level dicts with the scalar
        # statistics of the selected dataset.
        self.mean = self.mean[name]
        self.std = self.std[name]

        if download:
            self.download()

        if not self._check_datafile_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # load the serialized data (written by download() as a
        # (patches, labels, matches) tuple). NOTE: torch.load unpickles; only
        # point `root` at trusted data.
        self.data, self.labels, self.matches = torch.load(self.data_file)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            In train mode, the (optionally transformed) patch at ``index``.
            Otherwise a tuple ``(data1, data2, matches)``: the two patches
            referenced by match row ``index`` and that row's match flag.
        """
        if self.train:
            patch = self.data[index]
            if self.transform is not None:
                patch = self.transform(patch)
            return patch

        # Each match row is [patch_id_1, patch_id_2, is_match].
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self):
        if self.train:
            return self.lens[self.name]
        return len(self.matches)

    def _check_datafile_exists(self):
        return os.path.exists(self.data_file)

    def _check_downloaded(self):
        return os.path.exists(self.data_dir)

    def download(self):
        """Fetch, extract and cache the dataset as ``<root>/<name>.pt``.

        Does nothing if the cache file already exists. The zip archive is only
        downloaded when the extracted directory is missing, and it is deleted
        after extraction.
        """
        if self._check_datafile_exists():
            print('# Found cached data {}'.format(self.data_file))
            return

        if not self._check_downloaded():
            # download files
            url, filename, md5 = self.urls[self.name]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print('# Extracting data {}\n'.format(self.data_down))

            import zipfile
            with zipfile.ZipFile(fpath, 'r') as z:
                z.extractall(self.data_dir)

            os.unlink(fpath)

        # process and save as torch files
        print('# Caching data {}'.format(self.data_file))

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files)
        )

        with open(self.data_file, 'wb') as f:
            torch.save(dataset, f)
def read_image_file(data_dir, image_ext, n):
    """Return a Tensor containing the patches.

    Every file in ``data_dir`` whose name ends with ``image_ext`` is read in
    sorted filename order and cut, row by row, into 64x64 patches over a
    1024x1024 window (a 16x16 grid per image). Presumably every image is
    1024x1024; TODO confirm for all datasets.

    Args:
        data_dir (str): directory containing the patch images.
        image_ext (str): filename suffix of the images to read.
        n (int): number of leading patches to keep.

    Returns:
        torch.ByteTensor: the first ``n`` patches stacked; shape
        ``(k, 64, 64)`` with ``k = min(n, total patches)`` when any are read.
    """
    def PIL2array(_img):
        """Convert a single-channel 64x64 PIL image to a numpy uint8 2D array."""
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir, _image_ext):
        """Return the sorted paths of the images containing the patches."""
        files = [os.path.join(_data_dir, file_dir)
                 for file_dir in os.listdir(_data_dir)
                 if file_dir.endswith(_image_ext)]
        # sort files in ascending order so patch indices line up with the
        # info / matches files
        return sorted(files)

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once all patches of this image are extracted.
        with Image.open(fpath) as img:
            for y in range(0, 1024, 64):
                for x in range(0, 1024, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))


def read_info_file(data_dir, info_file):
    """Return a Tensor containing the list of labels.

    Read the file and keep only the ID of the 3D point (the first
    whitespace-separated column of each line). Blank lines are skipped.
    """
    with open(os.path.join(data_dir, info_file), 'r') as f:
        labels = [int(line.split()[0]) for line in f if line.strip()]
    return torch.LongTensor(labels)


def read_matches_files(data_dir, matches_file):
    """Return a Tensor containing the ground truth matches.

    Each output row is ``[patch_id_1, patch_id_2, is_match]``. The patch IDs
    come from columns 0 and 3 of the line. ``is_match`` is 1 when the 3D
    point IDs in columns 1 and 4 are equal, and 0 otherwise. Blank lines are
    skipped.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file), 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            matches.append([int(fields[0]), int(fields[3]),
                            int(fields[1] == fields[4])])
    return torch.LongTensor(matches)