Hyperspectral Image Classification -- HybridSN: Exploring 3-D–2-D CNN Feature Hierarchy for Hyperspectral Image Classification

"HybridSN: Exploring 3-D–2-D CNN Feature Hierarchy for Hyperspectral Image Classification"

S. K. Roy, G. Krishna, S. R. Dubey, and B. B. Chaudhuri, "HybridSN: Exploring 3-D–2-D CNN Feature Hierarchy for Hyperspectral Image Classification," IEEE GRSL, 2020.

This paper builds a hybrid network (3D convolution + 2D convolution) to tackle hyperspectral image classification.
Q: What is the difference between 3D and 2D convolution?
Both are multi-channel convolutions. A 3D convolution has an extra depth dimension along which the kernel also slides; a 2D convolution can be seen as the special case where the depth is 1.
2D convolution input shape: (batch_size, channel, height, width)
3D convolution input shape: (batch_size, channel, depth, height, width)
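For illustration, a minimal PyTorch sketch of the two operations (the sizes here are arbitrary and only for demonstration):

import torch
import torch.nn as nn

x2d = torch.randn(4, 3, 25, 25)        # (batch, channel, height, width)
x3d = torch.randn(4, 1, 30, 25, 25)    # (batch, channel, depth, height, width)

y2d = nn.Conv2d(3, 8, kernel_size=3)(x2d)             # kernel slides over H and W only
y3d = nn.Conv3d(1, 8, kernel_size=(7, 3, 3))(x3d)     # kernel also slides along depth
print(y2d.shape)    # torch.Size([4, 8, 23, 23])
print(y3d.shape)    # torch.Size([4, 8, 24, 23, 23])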

Model Architecture

[Figure: HybridSN network architecture]

3D convolution part:
conv1: input (1, 30, 25, 25), 8 kernels of size 7x3x3 ==> (8, 24, 23, 23)
conv2: input (8, 24, 23, 23), 16 kernels of size 5x3x3 ==> (16, 20, 21, 21)
conv3: input (16, 20, 21, 21), 32 kernels of size 3x3x3 ==> (32, 18, 19, 19)
2D convolution part:
Reshape the previous output by merging the 32 channels and 18 depth slices into 32*18 = 576 channels, giving (576, 19, 19)
conv4: 64 kernels of size 3x3 ==> (64, 17, 17)
A flatten operation then yields a 64*17*17 = 18496-dimensional vector
Next come fully connected layers with 256 and 128 nodes, each followed by Dropout with rate 0.4
The final output has 16 nodes, the number of classes
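The arithmetic above can be checked with dummy tensors; a quick sketch (not part of the original notebook):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 30, 25, 25)                       # (batch, channel, bands, H, W)
x = nn.Conv3d(1, 8, (7, 3, 3))(x);   print(x.shape)     # torch.Size([1, 8, 24, 23, 23])
x = nn.Conv3d(8, 16, (5, 3, 3))(x);  print(x.shape)     # torch.Size([1, 16, 20, 21, 21])
x = nn.Conv3d(16, 32, (3, 3, 3))(x); print(x.shape)     # torch.Size([1, 32, 18, 19, 19])
x = x.reshape(1, 32 * 18, 19, 19)                       # merge depth into channels: 576
x = nn.Conv2d(576, 64, (3, 3))(x);   print(x.shape)     # torch.Size([1, 64, 17, 17])
print(x.flatten(1).shape)                                # torch.Size([1, 18496])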

Network Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class HybridSN(nn.Module):
  def __init__(self):
    super(HybridSN, self).__init__()
    self.conv3d_1 = nn.Sequential(
        nn.Conv3d(1, 8, kernel_size=(7, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(8),
        nn.ReLU(inplace = True),
    )
    self.conv3d_2 = nn.Sequential(
        nn.Conv3d(8, 16, kernel_size=(5, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(16),
        nn.ReLU(inplace = True),
    ) 
    self.conv3d_3 = nn.Sequential(
        nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(32),
        nn.ReLU(inplace = True)
    )

    self.conv2d_4 = nn.Sequential(
        nn.Conv2d(576, 64, kernel_size=(3, 3), stride=1, padding=0),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace = True),
    )
    self.fc1 = nn.Linear(18496,256)
    self.fc2 = nn.Linear(256,128)
    self.fc3 = nn.Linear(128,16)
    self.dropout = nn.Dropout(p = 0.4)

  def forward(self,x):
    out = self.conv3d_1(x)
    out = self.conv3d_2(out)
    out = self.conv3d_3(out)
    # merge the depth dimension into channels: (B, 32, 18, 19, 19) -> (B, 576, 19, 19)
    out = self.conv2d_4(out.reshape(out.shape[0],-1,19,19))
    out = out.reshape(out.shape[0],-1)
    out = F.relu(self.dropout(self.fc1(out)))
    out = F.relu(self.dropout(self.fc2(out)))
    out = self.fc3(out)
    return out
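A quick sanity check of the forward pass (a sketch with random input: 25x25 patches with 30 PCA bands, 16 output classes):

net = HybridSN()
x = torch.randn(2, 1, 30, 25, 25)
print(net(x).shape)    # expected: torch.Size([2, 16])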

Training
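The train_loader used below is built in the data-preparation step, which this post does not show. A minimal sketch of what it might look like, assuming preprocessed patch arrays Xtrain of shape (N, 1, 30, 25, 25) and integer labels ytrain (both names are hypothetical):

import torch
from torch.utils.data import TensorDataset, DataLoader

# Xtrain: (N, 1, 30, 25, 25) float patches, ytrain: (N,) integer labels -- hypothetical names
train_dataset = TensorDataset(torch.from_numpy(Xtrain).float(),
                              torch.from_numpy(ytrain).long())
train_loader  = DataLoader(train_dataset, batch_size=128, shuffle=True)   # batch size is arbitrary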

import torch.optim as optim

# Train on the GPU (in Colab: "Runtime" -> "Change runtime type")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the network to the GPU
net = HybridSN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Start training
total_loss = 0
for epoch in range(100):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Zero the optimizer gradients
        optimizer.zero_grad()
        # Forward pass + backward pass + optimizer step
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print('[Epoch: %d]   [loss avg: %.4f]   [current loss: %.4f]' %(epoch + 1, total_loss/(epoch+1), loss.item()))

print('Finished Training')

[Epoch: 1] [loss avg: 19.6715] [current loss: 1.9873]
[Epoch: 2] [loss avg: 15.6281] [current loss: 1.0401]
[Epoch: 3] [loss avg: 13.1159] [current loss: 0.8418]
[Epoch: 4] [loss avg: 11.2389] [current loss: 0.5678]
[Epoch: 5] [loss avg: 9.6885] [current loss: 0.3103]
[Epoch: 6] [loss avg: 8.4845] [current loss: 0.2024]
[Epoch: 7] [loss avg: 7.5417] [current loss: 0.1180]
[Epoch: 8] [loss avg: 6.7961] [current loss: 0.2098]
[Epoch: 9] [loss avg: 6.1619] [current loss: 0.2179]
[Epoch: 10] [loss avg: 5.6341] [current loss: 0.0968]
[Epoch: 11] [loss avg: 5.1822] [current loss: 0.1244]
[Epoch: 12] [loss avg: 4.7971] [current loss: 0.0295]
[Epoch: 13] [loss avg: 4.4596] [current loss: 0.0777]
[Epoch: 14] [loss avg: 4.1665] [current loss: 0.0476]
[Epoch: 15] [loss avg: 3.9178] [current loss: 0.0910]
[Epoch: 16] [loss avg: 3.6988] [current loss: 0.0331]
[Epoch: 17] [loss avg: 3.5068] [current loss: 0.0327]
[Epoch: 18] [loss avg: 3.3333] [current loss: 0.0324]
[Epoch: 19] [loss avg: 3.1683] [current loss: 0.0548]
[Epoch: 20] [loss avg: 3.0232] [current loss: 0.0196]
[Epoch: 21] [loss avg: 2.8877] [current loss: 0.0215]
[Epoch: 22] [loss avg: 2.7742] [current loss: 0.0244]
[Epoch: 23] [loss avg: 2.6641] [current loss: 0.0139]
[Epoch: 24] [loss avg: 2.5632] [current loss: 0.0097]
[Epoch: 26] [loss avg: 2.3853] [current loss: 0.0657]
[Epoch: 27] [loss avg: 2.3084] [current loss: 0.0164]
[Epoch: 28] [loss avg: 2.2347] [current loss: 0.0143]
[Epoch: 29] [loss avg: 2.1671] [current loss: 0.0189]
[Epoch: 30] [loss avg: 2.1063] [current loss: 0.0331]
[Epoch: 31] [loss avg: 2.0429] [current loss: 0.0220]
[Epoch: 32] [loss avg: 1.9878] [current loss: 0.0173]
[Epoch: 33] [loss avg: 1.9329] [current loss: 0.0154]
[Epoch: 34] [loss avg: 1.8801] [current loss: 0.0306]
[Epoch: 35] [loss avg: 1.8305] [current loss: 0.0442]
[Epoch: 36] [loss avg: 1.7823] [current loss: 0.0077]
[Epoch: 37] [loss avg: 1.7399] [current loss: 0.0359]
[Epoch: 38] [loss avg: 1.6970] [current loss: 0.0004]
[Epoch: 39] [loss avg: 1.6546] [current loss: 0.0198]
[Epoch: 40] [loss avg: 1.6153] [current loss: 0.0126]
[Epoch: 41] [loss avg: 1.5795] [current loss: 0.0608]
[Epoch: 42] [loss avg: 1.5454] [current loss: 0.0296]
[Epoch: 43] [loss avg: 1.5121] [current loss: 0.0053]
[Epoch: 44] [loss avg: 1.4795] [current loss: 0.0325]
[Epoch: 45] [loss avg: 1.4492] [current loss: 0.0058]
[Epoch: 46] [loss avg: 1.4193] [current loss: 0.0130]
[Epoch: 47] [loss avg: 1.3904] [current loss: 0.0055]
[Epoch: 48] [loss avg: 1.3622] [current loss: 0.0013]
[Epoch: 49] [loss avg: 1.3354] [current loss: 0.0110]
[Epoch: 50] [loss avg: 1.3093] [current loss: 0.0041]
[Epoch: 51] [loss avg: 1.2845] [current loss: 0.0004]
[Epoch: 52] [loss avg: 1.2607] [current loss: 0.0028]
[Epoch: 53] [loss avg: 1.2375] [current loss: 0.0109]
[Epoch: 54] [loss avg: 1.2150] [current loss: 0.0010]
[Epoch: 55] [loss avg: 1.1934] [current loss: 0.0008]
[Epoch: 56] [loss avg: 1.1725] [current loss: 0.0011]
[Epoch: 57] [loss avg: 1.1524] [current loss: 0.0044]
[Epoch: 58] [loss avg: 1.1331] [current loss: 0.0012]
[Epoch: 59] [loss avg: 1.1151] [current loss: 0.0092]
[Epoch: 60] [loss avg: 1.0982] [current loss: 0.0035]
[Epoch: 61] [loss avg: 1.0822] [current loss: 0.0062]
[Epoch: 62] [loss avg: 1.0671] [current loss: 0.0412]
[Epoch: 63] [loss avg: 1.0517] [current loss: 0.0006]
[Epoch: 64] [loss avg: 1.0367] [current loss: 0.0206]
[Epoch: 65] [loss avg: 1.0213] [current loss: 0.0052]
[Epoch: 66] [loss avg: 1.0071] [current loss: 0.0027]
[Epoch: 67] [loss avg: 0.9940] [current loss: 0.0268]
[Epoch: 68] [loss avg: 0.9807] [current loss: 0.0500]
[Epoch: 69] [loss avg: 0.9686] [current loss: 0.0058]
[Epoch: 70] [loss avg: 0.9561] [current loss: 0.0032]
[Epoch: 71] [loss avg: 0.9454] [current loss: 0.0004]
[Epoch: 72] [loss avg: 0.9352] [current loss: 0.0764]
[Epoch: 73] [loss avg: 0.9268] [current loss: 0.0020]
[Epoch: 74] [loss avg: 0.9178] [current loss: 0.0977]
[Epoch: 75] [loss avg: 0.9097] [current loss: 0.0547]
[Epoch: 76] [loss avg: 0.8997] [current loss: 0.0115]
[Epoch: 77] [loss avg: 0.8921] [current loss: 0.0407]
[Epoch: 78] [loss avg: 0.8830] [current loss: 0.0063]
[Epoch: 79] [loss avg: 0.8748] [current loss: 0.0147]
[Epoch: 80] [loss avg: 0.8664] [current loss: 0.0070]
[Epoch: 81] [loss avg: 0.8570] [current loss: 0.0042]
[Epoch: 82] [loss avg: 0.8473] [current loss: 0.0022]
[Epoch: 83] [loss avg: 0.8382] [current loss: 0.0274]
[Epoch: 84] [loss avg: 0.8300] [current loss: 0.0600]
[Epoch: 85] [loss avg: 0.8209] [current loss: 0.0004]
[Epoch: 86] [loss avg: 0.8125] [current loss: 0.0008]
[Epoch: 87] [loss avg: 0.8038] [current loss: 0.0005]
[Epoch: 88] [loss avg: 0.7961] [current loss: 0.0016]
[Epoch: 89] [loss avg: 0.7889] [current loss: 0.0196]
[Epoch: 90] [loss avg: 0.7812] [current loss: 0.0037]
[Epoch: 91] [loss avg: 0.7748] [current loss: 0.0013]
[Epoch: 92] [loss avg: 0.7674] [current loss: 0.0358]
[Epoch: 93] [loss avg: 0.7600] [current loss: 0.0045]
[Epoch: 94] [loss avg: 0.7533] [current loss: 0.0483]
[Epoch: 95] [loss avg: 0.7462] [current loss: 0.0121]
[Epoch: 96] [loss avg: 0.7389] [current loss: 0.0052]
[Epoch: 97] [loss avg: 0.7316] [current loss: 0.0117]
[Epoch: 98] [loss avg: 0.7260] [current loss: 0.0028]
[Epoch: 99] [loss avg: 0.7200] [current loss: 0.0115]
[Epoch: 100] [loss avg: 0.7136] [current loss: 0.0001]
Finished Training

Model Testing

import numpy as np
from sklearn.metrics import classification_report

count = 0
# Model testing (test_loader and ytest come from the data-preparation step)
for inputs, _ in test_loader:
    inputs = inputs.to(device)
    outputs = net(inputs)
    outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)
    if count == 0:
        y_pred_test =  outputs
        count = 1
    else:
        y_pred_test = np.concatenate( (y_pred_test, outputs) )

# Generate the classification report
classification = classification_report(ytest, y_pred_test, digits=4)
print(classification)
     precision    recall  f1-score   support
     0.0     0.9444    0.8293    0.8831        41
     1.0     0.9769    0.9525    0.9645      1285
     2.0     0.9723    0.9853    0.9787       747
     3.0     0.9947    0.8779    0.9327       213
     4.0     0.9861    0.9793    0.9827       435
     5.0     0.9877    0.9802    0.9840       657
     6.0     1.0000    0.9600    0.9796        25
     7.0     0.9817    1.0000    0.9908       430
     8.0     0.7727    0.9444    0.8500        18
     9.0     0.9875    0.9909    0.9892       875
    10.0     0.9725    0.9910    0.9816      2210
    11.0     0.9499    0.9588    0.9543       534
    12.0     0.9572    0.9676    0.9624       185
    13.0     0.9887    0.9965    0.9926      1139
    14.0     0.9826    0.9741    0.9783       347
    15.0     0.8933    0.7976    0.8428        84
 accuracy                               0.9767      9225
 macro avg     0.9593    0.9491    0.9529      9225
weighted avg     0.9768    0.9767    0.9765      9225

The overall accuracy reaches 97.67%.
Q: If you train the network as above and then run the test several times, the classification results differ from run to run. Why?
A: The fully connected layers use nn.Dropout, which randomly deactivates units. Since the network was never switched to evaluation mode, dropout is still active at test time, so each run gives slightly different predictions.
Solution:

During training:

net.train() 

During testing:

net.eval()
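A minimal sketch of where the two calls go, using the training loop and test code from above:

net.train()                      # enable dropout before the training loop
# ... training loop as above ...

net.eval()                       # disable dropout for evaluation
with torch.no_grad():            # gradients are not needed at test time
    for inputs, _ in test_loader:
        outputs = net(inputs.to(device))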

Visualizing the Classification Result
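The helpers applyPCA and padWithZeros, as well as pca_components and patch_size, are defined in the data-preparation step that this post omits. A minimal sketch of what the two helpers typically look like (assuming scikit-learn's PCA; the actual definitions may differ):

import numpy as np
from sklearn.decomposition import PCA

def applyPCA(X, numComponents):
    # reduce the spectral dimension of an (H, W, bands) cube to numComponents
    flat = X.reshape(-1, X.shape[2])
    flat = PCA(n_components=numComponents, whiten=True).fit_transform(flat)
    return flat.reshape(X.shape[0], X.shape[1], numComponents)

def padWithZeros(X, margin):
    # zero-pad the spatial borders so that edge pixels also get full patches
    padded = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2 * margin, X.shape[2]))
    padded[margin:X.shape[0] + margin, margin:X.shape[1] + margin, :] = X
    return padded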

import scipy.io as sio
import spectral

# load the original image
X = sio.loadmat('Indian_pines_corrected.mat')['indian_pines_corrected']
y = sio.loadmat('Indian_pines_gt.mat')['indian_pines_gt']

height = y.shape[0]
width = y.shape[1]

X = applyPCA(X, numComponents= pca_components)
X = padWithZeros(X, patch_size//2)

# Predict the class of every labeled pixel
outputs = np.zeros((height,width))
for i in range(height):
    for j in range(width):
        if int(y[i,j]) == 0:
            continue
        else :
            image_patch = X[i:i+patch_size, j:j+patch_size, :]
            image_patch = image_patch.reshape(1,image_patch.shape[0],image_patch.shape[1], image_patch.shape[2], 1)
            X_test_image = torch.FloatTensor(image_patch.transpose(0, 4, 3, 1, 2)).to(device)                                   
            prediction = net(X_test_image)
            prediction = np.argmax(prediction.detach().cpu().numpy(), axis=1)
            outputs[i][j] = prediction+1
    if i % 20 == 0:
        print('... ... row ', i, ' handling ... ...')

… … row 0 handling … …
… … row 20 handling … …
… … row 40 handling … …
… … row 60 handling … …
… … row 80 handling … …
… … row 100 handling … …
… … row 120 handling … …
… … row 140 handling … …

predict_image = spectral.imshow(classes = outputs.astype(int),figsize =(5,5))

Result with the original code:
[Figure: classification map from the original code]
Result after adding the train/eval mode switch:
[Figure: classification map after adding net.train() / net.eval()]

The classification result is clearly better as well.

Adding an Attention Mechanism

Q: How could an attention mechanism be used to further improve hyperspectral image classification performance?
A: One option is to add CBAM-style channel and spatial attention blocks around the 2D convolution stage, as in the code below.

class_num = 16
# Channel attention (CBAM)
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
 
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
 
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
        
# Spatial attention (CBAM)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
 
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
 
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
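The attention maps produced by these two modules are multiplied back onto the feature map, channel attention first and spatial attention second (the usual CBAM ordering). A small sketch on a 576-channel feature map like the one produced after the 3D-to-2D reshape:

import torch

feat = torch.randn(1, 576, 19, 19)          # (batch, C, H, W)
ca = ChannelAttention(576)
sa = SpatialAttention(kernel_size=7)

feat = ca(feat) * feat    # re-weight channels
feat = sa(feat) * feat    # re-weight spatial positions
print(feat.shape)         # torch.Size([1, 576, 19, 19])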

class HybridSN(nn.Module):
  def __init__(self, num_classes, self_attention=False):
    super(HybridSN, self).__init__()
    self.conv3d_1 = nn.Sequential(
        nn.Conv3d(1, 8, kernel_size=(7, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(8),
        nn.ReLU(inplace = True),
    )
    
    self.self_attention = self_attention   # whether to apply the attention blocks
    
    self.conv3d_2 = nn.Sequential(
        nn.Conv3d(8, 16, kernel_size=(5, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(16),
        nn.ReLU(inplace = True),
    ) 
    self.conv3d_3 = nn.Sequential(
        nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1, padding=0),
        nn.BatchNorm3d(32),
        nn.ReLU(inplace = True)
    )

    if self_attention:
            self.channel_attention_1 = ChannelAttention(576)
            self.spatial_attention_1 = SpatialAttention(kernel_size=7)

    self.conv2d_4 = nn.Sequential(
        nn.Conv2d(576, 64, kernel_size=(3, 3), stride=1, padding=0),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace = True),
    )

    # second attention block, applied after conv2d_4 (64 channels)
    if self_attention:
            self.channel_attention_2 = ChannelAttention(64)
            self.spatial_attention_2 = SpatialAttention(kernel_size=7)
    
    self.fc1 = nn.Linear(18496, 256)
    self.fc2 = nn.Linear(256, 128)
    self.fc3 = nn.Linear(128, num_classes)
    self.dropout = nn.Dropout(p = 0.4)

  def forward(self, x):
    out = self.conv3d_1(x)
    out = self.conv3d_2(out)
    out = self.conv3d_3(out)
    # merge the depth dimension into channels: (B, 32, 18, 19, 19) -> (B, 576, 19, 19)
    out = out.reshape(out.shape[0], -1, 19, 19)
    if self.self_attention:
      out = self.channel_attention_1(out) * out
      out = self.spatial_attention_1(out) * out
    out = self.conv2d_4(out)
    if self.self_attention:
      out = self.channel_attention_2(out) * out
      out = self.spatial_attention_2(out) * out
    out = out.reshape(out.shape[0], -1)
    out = F.relu(self.dropout(self.fc1(out)))
    out = F.relu(self.dropout(self.fc2(out)))
    out = self.fc3(out)
    return out

# Random input to check that the network forward pass works
# x = torch.randn(1, 1, 30, 25, 25)
# net = HybridSN(num_classes=class_num, self_attention=True)
# y = net(x)
# print(y.shape)

With the attention blocks enabled, the test accuracy rises to 98.51%.
[Figure: test results with attention enabled]
Result image:
[Figure: classification map with attention enabled]