Torchvision Dataset
Loading the demo ImageNet vision dataset with torchvision in PyTorch. The dataset itself has to be downloaded separately, which requires signing up on the ImageNet website.
Python3
```python
# import the torch and torchvision packages
import torch
import torchvision

# access the dataset in the torchvision package using
# .datasets followed by the dataset name
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
```
Code Explanation:
- The procedure is almost the same as loading the audio data.
- Here, instead of torchaudio, torchvision has to be imported.
- Access the dataset class through torchvision.datasets, followed by the dataset name (here, ImageNet).
- Now, pass the path where the dataset is stored. Since ImageNet is no longer publicly accessible for automatic download, download the archives to your local system and pass their root directory to this function, which then loads the vision data from there. By default the training split is loaded; pass split='val' to load the validation split.
To load your own image data, wrap the image folder with datasets.ImageFolder and pass the result to torch.utils.data.DataLoader(data, batch_size, shuffle), as mentioned above.
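A minimal sketch of what the batch_size and shuffle arguments do, assuming a tiny in-memory TensorDataset as a stand-in for data (the toy features and labels are made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy dataset of 10 samples: features are 0..9, labels alternate 0/1.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape [10, 1]
targets = torch.arange(10) % 2                                  # shape [10]
data = TensorDataset(features, targets)

# batch_size groups samples; shuffle=True reorders them every epoch.
loader = DataLoader(data, batch_size=4, shuffle=True)
batch_sizes = [len(x) for x, y in loader]
print(batch_sizes)  # [4, 4, 2]: the last batch holds the leftover samples

# With shuffle=False, batches follow the dataset order.
ordered = DataLoader(data, batch_size=4, shuffle=False)
first_x, first_y = next(iter(ordered))
print(first_y)  # tensor([0, 1, 0, 1])
```

Since drop_last defaults to False, the final batch keeps the leftover samples; a dataset smaller than batch_size therefore yields a single, smaller batch.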
Python3
```python
# import the necessary modules
import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# specify the image dataset folder
# (ImageFolder expects one sub-folder per class)
data_dir = r'path to dataset\train'

# perform some transformations like resizing,
# center cropping and tensor conversion
transform = transforms.Compose([transforms.Resize(255),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])

# pass the image data folder and the
# transform to datasets.ImageFolder
dataset = datasets.ImageFolder(data_dir, transform=transform)

# now use DataLoader to load the dataset
# in batches with the specified transformation
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# next(iter(...)) fetches one batch of
# images and labels into two variables
images, labels = next(iter(dataloader))

# print the number of samples in this batch
print('Number of samples: ', len(images))

image = images[2][0]  # first channel of the 3rd sample

# visualize the image
plt.imshow(image, cmap='gray')
plt.show()

# print the size of the image
print("Image Size: ", image.size())

# print the labels of the batch
print(labels)
```
Output:
Image Size:  torch.Size([224, 224])
tensor([0, 0, 0, 1, 1, 1])
Loading Data in PyTorch
In this article, we will discuss how to load different kinds of data in PyTorch.
For demonstration purposes, PyTorch offers three domain libraries with built-in datasets, namely torchaudio, torchvision, and torchtext. We can use these demo datasets to understand how to load sound, image, and text data in PyTorch.