【AI论文复现】Pansharpening via Detail Injection Based Convolutional Neural Networks

一、网络结构

该篇论文的DiCNN1 网络结构图如下:

在这里插入图片描述

二、代码实现

2.1、基于TensorFlow1.14.0实现核心代码

########## DiCNN1Net structures ################
def DiCNN1Net(lms, pan, num_spectral = 8, num_res = 4, num_fm = 32, reuse=False):
    
    weight_decay = 1e-5

    with tf.variable_scope('net'):        
        if reuse:
            tf.get_variable_scope().reuse_variables()


        ms_1 = tf.concat([lms,pan],axis=3)

        rs = ly.conv2d(ms_1, num_outputs = num_fm, kernel_size = 3, stride = 1,
                          weights_regularizer = ly.l2_regularizer(weight_decay), 
                          weights_initializer = ly.variance_scaling_initializer(),
                          activation_fn = tf.nn.relu)

        rs = ly.conv2d(rs, num_outputs = num_fm, kernel_size = 3, stride = 1,
                          weights_regularizer = ly.l2_regularizer(weight_decay),
                          weights_initializer = ly.variance_scaling_initializer(),
                          activation_fn = tf.nn.relu)

        rs = ly.conv2d(rs, num_outputs = num_spectral, kernel_size = 3, stride = 1,
                          weights_regularizer = ly.l2_regularizer(weight_decay),
                          weights_initializer = ly.variance_scaling_initializer(),
                          activation_fn = None)
            
        rs = tf.add(rs,lms)

        return rs

2.2、基于pytorch1.7.1实现核心代码

########## DiCNN1Net structures ################
class DiCNN1Net(nn.Module):
    def __init__(self):
        super(PanNet, self).__init__()

        channel = 32
        spectral_num = 8

        self.conv1 = nn.Conv2d(in_channels=spectral_num + 1, out_channels=channel, kernel_size=3, stride=1, padding=1,
                               bias=True) #输入9 输出32

        self.conv2 = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1,
                               bias=True) #输入32 输出32

        self.conv3 = nn.Conv2d(in_channels=channel, out_channels=spectral_num, kernel_size=3, stride=1, padding=1,
                               bias=True) #输入32 输出8

        self.relu = nn.ReLU(inplace=True)

        # init_weights(self.conv1, self.conv2, self.conv3)   # state initialization, important!

    def forward(self, x, y):  # x= lms; y = pan 该函数才是网络的流程!

        input = torch.cat([x, y], 1)  # Bsx9x64x64 channel在第1个位置

        rs1 = self.relu(self.conv1(input))  # Bsx32x64x64

        rs2 = self.relu(self.conv2(rs1))  # Bsx32x64x64

        rs3 = self.conv3(rs2)  # Bsx8x64x64

        output = torch.add(x, rs3)  # Bsx8x64x64

        return output

卷积层初始化(init_weights),可以先不使用,torch会默认完成对应的初始化(默认是kaiming init),这里PanNet使用是因为 需要如此调试,一般默认的就够用,不用增加麻烦。当然在实际中,也有可能更改初始化,在引进此处的代码即可。

2.3、MATLAB图像显示代码

训练好模型后,测试输出的是一个.mat格式的文件,该文件是一个256x256x8的图像数据,在日常生活中我们一般使用的图像数据是RGB三个通道的,仿照程序已有的代码,在八个通道中抽出三个通道作为RGB…

def vis_ms(data):
    _,b,g,_,r,_,_,_ = tf.split(data,8,axis = 3)
    vis = tf.concat([r,g,b],axis = 3)
    return vis

由已有代码可知,R通道是该数据的第5个通道,G通道是该数据的第3个通道,B通道是该数据的第2个通道。有了RGB三通道的图像数据,我们就很方便的可以显示图像了,matlab代码如下所示:

clc;
clear;
close all;

load('output.mat');
img_r = output(:,:,5);
img_g = output(:,:,3);
img_b = output(:,:,2);
img(:,:,1) = img_r;
img(:,:,2) = img_g;
img(:,:,3) = img_b;
imshow(img)

subplot(1,4,1),imshow(img);title(['RGB图']);
subplot(1,4,2),imshow(img_r );title(['R图']);
subplot(1,4,3),imshow(img_g );title(['G图']);
subplot(1,4,4),imshow(img_b );title(['B图']);

在这里插入图片描述


参考

ReCclay CSDN认证博客专家 嵌入式软件开发 机器/深度学习 全栈技术学习者
大家好,我是CSDN博主ReCclay,目前处于研究生阶段,就读于电子科技大学,主攻方向为汽车辅助驾驶算法研究。入站以来,凭借坚持与热爱,以博文的方式分享所学,截止目前累计博文数量达800余篇,累计受益人次达130w+次,涉及领域包括但不限于物联网开发、单片机开发、Linux驱动开发、FPGA开发、前/后端软件开发等。在未来我将继续专注于嵌入式相关领域,学习更多的科技知识,输出更高质量的博文。
已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 数字20 设计师:CSDN官方博客 返回首页
实付 29.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值