GitHub开源项目Hyperspectral-Classification的解析

余生颓废 提交于 2019-11-27 05:51:25

GitHub链接:Hyperspectral-Classification Pytorch

项目简介

项目的作者是Xidian university,是基于PyTorch的高光谱图像地物目标的分类程序。该项目兼容Python 2.7和Python 3.5+,基于PyTorch深度学习和GPU计算框架,并使用Visdom可视化服务器。

预定义的公开的数据集有:

  • 帕维亚大学
  • 帕维亚中心
  • 肯尼迪航天中心
  • 印度松树
  • 博茨瓦纳

用户也可添加自定义的数据集,示例是“数据融合大赛2018的高光谱数据集”DFC2018_HSI。开发人员应该为CUSTOM_DATASETS_CONFIG变量添加一个新条目,并为其用例定义特定的数据加载器。

该工具实现了scikit-learn库中的几个SVM变体以及PyTorch中实现的许多最先进的深度网络:

  • SVM(带网格搜索的线性,RBF和多核)
  • SGD(使用随机梯度下降的线性SVM进行快速优化)
    基线神经网络(4个完全连接的层,有丢失)
  • 1D CNN(用于高光谱图像分类的深度卷积神经网络,Hu等人,Journal of Sensors 2015)
  • 半监督的1D CNN(Autoencodeurs pour la visualization d’images hyperspectrales,Boulch et al。,GRETSI 2017)
  • 2D CNN(用于图像分类和频带选择的高光谱CNN,应用于人脸识别,Sharma等,技术报告2018)
  • 半监督2D CNN(用于高光谱图像分类的半监督卷积神经网络,Liu等,遥感信函2017)
  • 3D CNN(用于遥感图像分类的三维深度学习方法,Hamida等,TGRS 2018)
  • 3D FCN(基于上下文深度CNN的高光谱分类,Lee和Kwon,IGARSS 2016)
  • 3D CNN(基于卷积神经网络的深度特征提取和高光谱图像分类,Chen等,TGRS 2016)
  • 3D CNN(三维卷积神经网络的高光谱图像的光谱 - 空间分类,Li等,遥感2017)
  • 3D CNN(HSI-CNN:用于高光谱图像的新型卷积神经网络,Luo等,ICPR 2018)
  • 多尺度3D CNN(用于高光谱图像分类的多尺度3D深度卷积神经网络,He等,ICIP 2017)

用户也可以通过修改models.py文件来添加自定义深层网络。这意味着为自定义深层网络创建一个新类并更改该get_model功能。

项目各模块和函数的解析

utils.py


get_device(ordinal)

功能:

根据输入参数,判断device为CPU或GPU。

输入和输出:

输入:

  • ordinal:一个int类型的数,表示用哪个GPU

输出:

  • device:一个超参数,表示运算的位置(CPU or GPU)
代码:
def get_device(ordinal):
    # Use GPU ?
    if ordinal < 0:
        print("Computation on CPU")
        device = torch.device('cpu')
    elif torch.cuda.is_available():
        print("Computation on CUDA GPU device {}".format(ordinal))
        device = torch.device('cuda:{}'.format(ordinal))
    else:
        print("/!\\ CUDA was requested but is not available! Computation will go on CPU. /!\\")
        device = torch.device('cpu')
    return device
解析:

其实就是一个简单的分支结构:

  • ordinal < 0:CPU
  • ordinal < 0orch.cuda.is_available() == True:GPU
  • ordinal < 0orch.cuda.is_available() == False:CPU

open_file(dataset)

功能:

打开指定的数据集的文件。

输入和输出:

输入:

  • dataset:数据集文件的完整路径,比如C:\Datasets\OwnData\OwnData.mat

输出(以读取.mat为例,因为读取的以.mat文件居多):

  • 一个以变量名,以数据的字典dictionary
代码:
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    if ext == '.mat':
        # Load Matlab array
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        # Load TIFF file
        return misc.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))
解析:

最重要的是 _, ext = os.path.splitext(dataset)中的os.path.splitext(path)函数。
该函数将输入的路径path拆分为文件名 + 扩展名,并依次作为返回值。_, ext表示只获取扩展名,存入变量ext。之后就是根据不同的扩展名选择不同的打开方式。

需要注意的是,打开.mat文件,返回值是一个以变量名,以数据的字典dictionary。要取出其中的数据,需要通过字典操作,通过访问来获取,比如img = open_file(folder + 'OwnData.mat')['Data']


标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!