How to use torchvision.transforms for data augmentation of segmentation task in Pytorch?

Submitted by 北战南征 on 2021-02-09 00:50:49

Question


I am a little bit confused about the data augmentation performed in PyTorch.

Because we are dealing with a segmentation task, the data and the mask must receive exactly the same augmentation, but some augmentations are random, such as random rotation.

In Keras, passing the same random seed to the image and mask generators guarantees that the data and the mask undergo the same operations, as shown in the following code:

    from keras.preprocessing.image import ImageDataGenerator

    data_gen_args = dict(featurewise_center=True,
                         featurewise_std_normalization=True,
                         rotation_range=25,
                         horizontal_flip=True,
                         vertical_flip=True)


    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    seed = 1
    image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
    mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

    train_generator = zip(image_generator, mask_generator)
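For comparison, the Keras trick works because two random generators created from the same seed produce identical sequences of draws. The sketch below reproduces that idea in plain NumPy; it is not a PyTorch API, and `random_flip` is a hypothetical helper. It only stays in sync if both pipelines consume random numbers in exactly the same order, which is one reason to prefer a transform that takes both inputs at once.

```python
import numpy as np

def random_flip(array, rng):
    """Flip left-right, then up-down, each with probability 0.5, using rng."""
    if rng.random() < 0.5:
        array = array[:, ::-1]
    if rng.random() < 0.5:
        array = array[::-1, :]
    return array

seed = 1
image = np.arange(12).reshape(3, 4)
mask = (image % 2).astype(np.uint8)

# Two generators created from the same seed produce the same sequence of
# draws, so both arrays get identical flips, provided both pipelines
# consume random numbers in exactly the same order.
image_rng = np.random.default_rng(seed)
mask_rng = np.random.default_rng(seed)

aug_image = random_flip(image, image_rng)
aug_mask = random_flip(mask, mask_rng)
print(aug_image)
print(aug_mask)
```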

I didn't find a similar description in the official PyTorch documentation, so I don't know how to ensure that the data and the mask are processed in sync.

PyTorch does provide such a function, but I want to apply it in a custom DataLoader.

For example:

def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))

    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')

    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

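    if self.transforms:
        img = np.uint8(img)
        mask = np.uint8(mask)
        # Each call draws its own random parameters, so img and mask
        # may receive different rotations/flips
        img = self.transforms(img)
        mask = self.transforms(mask)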

    return img, mask

In this case, img and mask are transformed independently. Because operations such as random rotation draw their parameters at random, the correspondence between image and mask can be lost: the image may be rotated while the mask is not, or the two may be rotated by different angles.
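One way to keep the pair in sync is to make each random transform accept both arrays, draw its random parameter once, and apply that same parameter to each. A minimal NumPy sketch, assuming H x W (x C) arrays; `RandomRot90Pair` is a hypothetical name, not a torchvision class:

```python
import numpy as np

class RandomRot90Pair:
    """Rotate image and mask by the same random multiple of 90 degrees."""

    def __init__(self, rng=None):
        self.rng = rng if rng is not None else np.random.default_rng()

    def __call__(self, img, mask):
        k = int(self.rng.integers(0, 4))  # one draw shared by both inputs
        # np.rot90 rotates in the plane of the first two axes (H, W);
        # copy() yields contiguous arrays (torch.from_numpy rejects negative strides)
        return np.rot90(img, k).copy(), np.rot90(mask, k).copy()

transform = RandomRot90Pair(np.random.default_rng(0))
img = np.arange(2 * 3 * 2).reshape(2, 3, 2).astype(np.uint8)
mask = (img[:, :, 0] > 4).astype(np.uint8)
img_t, mask_t = transform(img, mask)
```

In the `__getitem__` above, such a transform would replace the two separate calls with a single `img, mask = self.transforms(img, mask)`.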

EDIT 1

I used the approach in augmentations.py, but I got an error:

Traceback (most recent call last):
  File "test_transform.py", line 87, in <module>
    for batch_idx, image, mask in enumerate(train_loader):
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
    img, mask = self.transforms(img, mask)
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
    img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

This is my code, including __getitem__():

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        transforms.ToTensor()
    ]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))
    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')  # images come from Image_path
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = T.ToPILImage()(np.uint8(img))
        mask = T.ToPILImage()(np.uint8(mask))
        img, mask = self.transforms(img, mask)

    # ToTensor already returns a new tensor; torch tensors have no .copy() method
    img = T.ToTensor()(img)
    mask = T.ToTensor()(mask)
    return img, mask

EDIT 2

I found that after ToTensor, the Dice score between identical labels becomes 255 instead of 1. How can I fix this?

# Dice computation
def DSC_computation(label, pred):
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)

Feel free to ask if more code is needed to explain the problem.
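A likely cause of the 255: for a uint8 array or an 8-bit PIL image, ToTensor() rescales values to [0, 1] by dividing by 255, so a mask labelled 1 becomes 1/255. DSC_computation mixes a pixel count (np.logical_and(...).sum()) with raw value sums, so for identical masks with N foreground pixels it returns 2N / (2N/255) = 255. The sketch below reproduces this in NumPy, simulating ToTensor's scaling as a division by 255 (a stand-in, not the real call), and shows one possible fix: binarize before computing Dice. `dice_binary` is a hypothetical helper. Another option is to keep the mask as integer labels, e.g. `torch.as_tensor(np.array(mask), dtype=torch.long)` instead of ToTensor.

```python
import numpy as np

def DSC_computation(label, pred):
    # Function from the question: mixes a pixel count with value sums
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)

def dice_binary(label, pred):
    # Binarize first so every term is a pixel count
    label = label > 0
    pred = pred > 0
    inter_sum = np.logical_and(pred, label).sum()
    return 2.0 * inter_sum / (pred.sum() + label.sum())

mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 1  # four foreground pixels labelled 1

# Simulates ToTensor() on a uint8 mask: values divided by 255
scaled = mask.astype(np.float32) / 255

print(DSC_computation(scaled, scaled))  # about 255 rather than 1
print(DSC_computation(mask, mask))      # 1.0 on the unscaled labels
print(dice_binary(scaled, scaled))      # 1.0 after binarizing
```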


Answer 1:


torchvision also provides similar functions (see the torchvision.transforms documentation).

Here is a simple example:

import torchvision
from torchvision import transforms

trans = transforms.Compose([transforms.CenterCrop((178, 178)),
                            transforms.Resize(128),
                            transforms.RandomRotation(20),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,))])  # MNIST images have one channel
# data_root: directory holding the MNIST files; the keyword is transform, not transforms
dset = torchvision.datasets.MNIST(data_root, transform=trans)

EDIT

A brief example of a custom CelebA dataset. Note that to apply the transformations, you need to call the transform pipeline in __getitem__.

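import os

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CelebADataset(Dataset):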
    def __init__(self, root, transforms=None, num=None):
        super(CelebADataset, self).__init__()

        self.img_root = os.path.join(root, 'img_align_celeba')
        self.attr_root = os.path.join(root, 'Anno/list_attr_celeba.txt')
        self.transforms = transforms

        df = pd.read_csv(self.attr_root, sep=r'\s+', header=1, index_col=0)
        #print(df.columns.tolist())
        if num is None:
            self.labels = df.values
            self.img_name = df.index.values
        else:
            self.labels = df.values[:num]
            self.img_name = df.index.values[:num]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_root, self.img_name[index]))
        # only use blond_hair, eyeglass, male, smile
        indices = [9, 15, 20, 31]
        label = np.take(self.labels[index], indices)
        label[label==-1] = 0

        if self.transforms is not None:
            img = self.transforms(img)

        return np.asarray(img), label

    def __len__(self):
        return len(self.labels)


EDIT 2

I probably missed something at first glance. The main point of your question is how to apply "the same" data preprocessing to both the image and the label. To my knowledge, PyTorch has no built-in function for this, so what I did before was implement the augmentation myself.

import random

from PIL import Image

class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        # Draw one angle in [-degree, degree] and use it for both inputs
        rotate_degree = random.random() * 2 * self.degree - self.degree
        return (img.rotate(rotate_degree, Image.BILINEAR),
                mask.rotate(rotate_degree, Image.NEAREST))

Note that the inputs should be PIL images.
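A standalone usage sketch of the same idea as RandomRotate, assuming an 8-bit grayscale image and a 0/1 label mask stored as mode "L" PIL images (the class name `RandomRotatePair` is hypothetical). It also shows why the mask uses Image.NEAREST: nearest-neighbour resampling only copies existing pixel values, so no fractional labels appear along object boundaries, whereas bilinear interpolation would blend neighbouring labels.

```python
import random

import numpy as np
from PIL import Image

class RandomRotatePair:
    """Rotate image and mask by one shared random angle in [-degree, degree]."""

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        angle = random.uniform(-self.degree, self.degree)  # drawn once
        # Bilinear for the image, nearest for the mask so labels stay discrete
        return (img.rotate(angle, Image.BILINEAR),
                mask.rotate(angle, Image.NEAREST))

random.seed(0)
labels = np.zeros((32, 32), dtype=np.uint8)
labels[8:24, 8:24] = 1
img = Image.fromarray(labels * 200)   # grayscale image, mode "L"
mask = Image.fromarray(labels)        # label mask, mode "L"

img_r, mask_r = RandomRotatePair(degree=25)(img, mask)
print(np.unique(np.array(mask_r)))    # only label values 0 and 1
```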




Answer 2:


Transforms that depend on random parameters, such as RandomCrop, have a get_params static method that returns the parameters for that particular transformation. These parameters can then be applied to both the image and the mask using the functional interface of transforms (torchvision.transforms.functional):

from torchvision import transforms
import torchvision.transforms.functional as F

# image and target: a PIL image and its mask of the same size
i, j, h, w = transforms.RandomCrop.get_params(image, (100, 100))
image = F.crop(image, i, j, h, w)
target = F.crop(target, i, j, h, w)

A sample is available here: https://github.com/pytorch/vision/releases/tag/v0.2.0

Complete examples for VOC and COCO are available here:
https://github.com/pytorch/vision/blob/master/references/segmentation/transforms.py
https://github.com/pytorch/vision/blob/master/references/segmentation/train.py
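The get_params pattern (sample the random parameters once, then apply them with a deterministic function) does not depend on torchvision itself. A framework-free NumPy sketch follows; `random_crop_params` and `crop` are hypothetical helpers that mimic the (i, j, h, w) convention of transforms.RandomCrop.get_params and F.crop, where i and j are the top row and left column of the crop.

```python
import numpy as np

def random_crop_params(height, width, output_size, rng):
    """Return (i, j, h, w) for a random crop, mimicking RandomCrop.get_params."""
    th, tw = output_size
    if height < th or width < tw:
        raise ValueError("crop size larger than input")
    i = int(rng.integers(0, height - th + 1))  # top row
    j = int(rng.integers(0, width - tw + 1))   # left column
    return i, j, th, tw

def crop(array, i, j, h, w):
    """Deterministic crop of the first two (H, W) axes."""
    return array[i:i + h, j:j + w]

rng = np.random.default_rng(42)
image = rng.integers(0, 256, size=(120, 160, 3), dtype=np.uint8)
target = (image[:, :, 0] > 127).astype(np.uint8)

# Sample once, then apply the same parameters to both arrays
i, j, h, w = random_crop_params(*image.shape[:2], (100, 100), rng)
image_crop = crop(image, i, j, h, w)
target_crop = crop(target, i, j, h, w)
```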

Regarding the error:

ToTensor() does not accept the additional mask argument, so it cannot be in data_transforms: the custom Compose calls every transform as a(img, mask), which is what raises the TypeError. Moreover, __getitem__ already applies ToTensor to both img and mask before returning them.

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        #transforms.ToTensor()  => remove this line
    ]),
}
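To see why the TypeError appears and why removing ToTensor from the list fixes it, here is a stripped-down sketch of a paired Compose of the kind augmentations.py appears to use (the traceback shows its Compose calling `a(img, mask)`). `PairCompose`, `RandomHorizontalFlipPair`, `single_arg_transform`, and `to_arrays` are hypothetical names, and NumPy arrays stand in for PIL images and tensors. A one-argument transform breaks the chain, so the conversion is done after the paired transforms, here keeping the mask as integer labels.

```python
import numpy as np

class PairCompose:
    """Chain transforms that each take and return (img, mask)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)  # every step must accept two arguments
        return img, mask

class RandomHorizontalFlipPair:
    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = rng if rng is not None else np.random.default_rng()

    def __call__(self, img, mask):
        if self.rng.random() < self.p:  # single decision for both inputs
            img, mask = img[:, ::-1], mask[:, ::-1]
        return img, mask

def single_arg_transform(img):
    # Stands in for a one-argument transform such as ToTensor
    return img

def to_arrays(img, mask):
    # Image: float in [0, 1] (as ToTensor does for uint8); mask: integer labels
    return img.astype(np.float32) / 255.0, mask.astype(np.int64)

img = np.full((4, 4), 255, dtype=np.uint8)
mask = np.eye(4, dtype=np.uint8)

broken = PairCompose([RandomHorizontalFlipPair(), single_arg_transform])
try:
    broken(img, mask)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)  # same failure mode as in the traceback

pipeline = PairCompose([RandomHorizontalFlipPair(p=1.0)])
img_out, mask_out = to_arrays(*pipeline(img, mask))
```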


Source: https://stackoverflow.com/questions/58215056/how-to-use-torchvision-transforms-for-data-augmentation-of-segmentation-task-in
