How to iterate over two dataloaders simultaneously using pytorch?

星月不相逢 2021-02-06 06:04

I am trying to implement a Siamese network that takes in two images. I load these images and create two separate dataloaders.

In my loop I want to go through both dataloaders simultaneously, so that I can train the network on both images.

5 Answers
  •  自闭症患者
    2021-02-06 06:40

    Adding on to @Aldream's solution: for the case where the datasets have different lengths and we want to pass through all of them in the same epoch, we can use cycle() from itertools in the Python standard library. Using @Aldream's code snippet, the updated code will look like:

    from torch.utils.data import DataLoader, Dataset
    from itertools import cycle
    
    class DummyDataset(Dataset):
        """
        Dataset of the integers in [a, b] inclusive.
        """
    
        def __init__(self, a=0, b=100):
            super(DummyDataset, self).__init__()
            self.a = a
            self.b = b
    
        def __len__(self):
            return self.b - self.a + 1
    
        def __getitem__(self, index):
            return self.a + index  # map index 0..len-1 onto the value a..b
    
    # 101 samples -> 11 batches; 201 samples -> 21 batches
    dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
    dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
    num_epochs = 10
    
    for epoch in range(num_epochs):
        for i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
            # cycle() repeats the shorter loader, so zip() runs until
            # the longer loader (dataloaders2) is exhausted
            print(data)
    

    With zip() alone, the combined iterator is exhausted as soon as the shorter loader is, i.e. after the 11 batches of the 101-sample dataset. With cycle(), the shorter loader is repeated until every batch of the longer one (the 201-sample dataset, 21 batches) has been seen.
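
    For concreteness, a minimal sketch of the zip()-only behaviour described above, reusing dataloaders1 and dataloaders2 from the snippet:

    for batch1, batch2 in zip(dataloaders1, dataloaders2):
        # stops after 11 iterations, when dataloaders1 is exhausted,
        # so the last 10 batches of dataloaders2 are never seen
        print(batch1, batch2)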

    P.S. One can always argue that this approach is not required to achieve convergence as long as one samples randomly, but it can make evaluation easier.
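
    One caveat worth noting: itertools.cycle caches every batch it yields, so with shuffle=True the shorter loader is replayed in its first-epoch order rather than reshuffled, and the cached batches stay in memory. A sketch of a common workaround, a small generator (cycled() is a hypothetical helper name, not part of PyTorch) that re-creates the iterator on each pass, assuming the same loaders as above:

    def cycled(loader):
        # Yield batches forever; re-creating the iterator on each pass
        # re-applies shuffling and avoids cycle()'s caching.
        while True:
            for batch in loader:
                yield batch

    for epoch in range(num_epochs):
        # zip() still stops when dataloaders2 (the longer loader) is exhausted
        for i, data in enumerate(zip(cycled(dataloaders1), dataloaders2)):
            print(data)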
