pytorch只用中要注意通道问题

删除回忆录丶 提交于 2019-11-27 08:31:01

cv读进来的是BGR图像,通道是hcw,在torch中使用要注意维度转换

def __getitem__(self, idx):
        '''Load image.

        Args:
          idx: (int) image index.
       img_org = Image.open(self.root_src  +  'reference_cutBlock' + fname_org)
        Returns:
          img: (tensor) image tensor.
          loc_targets: (tensor) location targets.
          cls_targets: (tensor) class label targets.
        '''
        # Load image
        fname_org = self.fnames[idx]
        img_org = cv2.imread(self.root_src + 'dn_dataset/' + fname_org)
        # img_org = np.asarray(img_org)

        coin = np.random.randint(0, 50)
        img_dis = skimage.util.random_noise(img_org, mode='gaussian', seed=None,
                                            var=(coin / 255.0) ** 2)  # add  gaussian noise

        # img_dis = img_dis[:, :, (2, 1, 0)]  # bgr012 to rgb210
        img_dis = img_dis.transpose([2, 0, 1])  # hwc to chw
        img_dis = img_dis[(2, 1, 0), :, :]  # bgr012 to rgb210

        img_org = img_org[:, :, (2, 1, 0)]/255.0  # bgr012 to rgb210
        img_org = img_org.transpose([2, 0, 1])  # hwc to chw

        img_dis = torch.from_numpy(img_dis).float()
        img_org = torch.from_numpy(img_org).float()
        # fname_org_dis = self.fnames_dis[idx]
        # img_dis = Image.open(self.root_src  +  'distorted_train_block/' + fname_org_dis)

        # if img_org.mode != 'RGB':
        #     img_org = img_org.convert('RGB')
        #
        # if img_dis.mode != 'RGB':
        #     img_dis = img_dis.convert('RGB')
        # img_org = self.transform(img_org)
        # img_dis = self.transform(img_dis)

        return img_dis, img_org

transforms.ToTensor() 有两层含义,一个是转化成Tensor,另一个是进行归一化,此段代码,没有采用此语句,而是分两步完成,因为img_dis,已经实现归一化。

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!