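This model is a small modification of DnCNN-keras. The original trains from a single input: it takes only clean images, adds synthetic Gaussian noise to them on the fly, and learns to remove that noise. The modified version instead takes paired inputs, real noisy images and their corresponding clean images, and trains on them with supervision, with the aim of making the denoising more specialized to, and more accurate on, that real noise.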

main.py

First, modify the data generator:

Original code:

def train_datagen(y_, batch_size=8):
    indices = list(range(y_.shape[0]))
    while(True):
        np.random.shuffle(indices)    
        for i in range(0, len(indices), batch_size):
            ge_batch_y = y_[indices[i:i+batch_size]]
            noise =  np.random.normal(0, args.sigma/255.0, ge_batch_y.shape)  
            ge_batch_x = ge_batch_y + noise 
            yield ge_batch_x, ge_batch_y

Modified:

def train_datagen(x_, y_, batch_size=8):  # added the x_ parameter
    indices = list(range(y_.shape[0]))
    while (True):
        np.random.shuffle(indices)
        for i in range(0, len(indices), batch_size):
            ge_batch_x = x_[indices[i:i + batch_size]]  # batch x with the same indices as y
            ge_batch_y = y_[indices[i:i + batch_size]]
            # noise =  np.random.normal(0, args.sigma/255.0, ge_batch_y.shape)    # noise
            # #noise =  K.random_normal(ge_batch_y.shape, mean=0, stddev=args.sigma/255.0)
            # ge_batch_x = ge_batch_y + noise  # input image = clean image + noise
            # the noise-adding lines above are commented out
            yield ge_batch_x, ge_batch_y
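For reference, here is the modified generator as a standalone sketch, with the commented-out noise lines dropped. The shape `assert` is my own addition and not part of the original change; it fails early if x and y were packed inconsistently. The toy arrays are made up so that each input equals its label plus 1, which makes it easy to confirm that shuffling keeps every pair aligned, since both arrays are indexed with the same shuffled indices.

```python
import numpy as np

def train_datagen(x_, y_, batch_size=8):
    """Yield (noisy, clean) batches forever, reshuffling once per pass.

    The same shuffled index list selects from both arrays, so each noisy
    patch stays paired with its own clean label.
    """
    # Added check (not in the original change): inputs and labels must match.
    assert x_.shape == y_.shape, "inputs and labels must have identical shapes"
    indices = list(range(y_.shape[0]))
    while True:
        np.random.shuffle(indices)
        for i in range(0, len(indices), batch_size):
            batch_idx = indices[i:i + batch_size]
            yield x_[batch_idx], y_[batch_idx]

# Made-up toy data: each "noisy" input is its label + 1.
rng = np.random.default_rng(0)
y = rng.random((10, 4, 4, 1)).astype('float32')
x = y + 1.0
gen = train_datagen(x, y, batch_size=4)
bx, by = next(gen)
print(bx.shape, np.allclose(bx - by, 1.0))  # (4, 4, 4, 1) True
```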

Next, modify the training function:

Original function:

def train():

    data = load_train_data()
    data = data.reshape((data.shape[0],data.shape[1],data.shape[2],1))
    data = data.astype('float32')/255.0
    if args.pretrain:   model = load_model(args.pretrain, compile=False)
    else:   
        if args.model == 'DnCNN': model = models.DnCNN()
    model.compile(optimizer=Adam(), loss=['mse'])
    ckpt = ModelCheckpoint(save_dir+'/model_{epoch:02d}.h5', monitor='val_loss', 
                    verbose=0, period=args.save_every)
    csv_logger = CSVLogger(save_dir+'/log.csv', append=True, separator=',')
    lr = LearningRateScheduler(step_decay)
    history = model.fit_generator(train_datagen(data, batch_size=args.batch_size),
                    steps_per_epoch=len(data)//args.batch_size, epochs=args.epoch, verbose=1, 
                    callbacks=[ckpt, csv_logger, lr])

    return model

Modified:

def train():
    datax = load_train_data(args.train_datax)
    datax = datax.reshape((datax.shape[0], datax.shape[1], datax.shape[2], 1))
    datax = datax.astype('float32') / 255.0
    # same steps again, this time loading the labels (clean images)
    datay = load_train_data(args.train_datay)
    datay = datay.reshape((datay.shape[0], datay.shape[1], datay.shape[2], 1))
    datay = datay.astype('float32') / 255.0

    if args.pretrain:
        model = load_model(args.pretrain, compile=False)
    else:
        if args.model == 'DnCNN': model = models.DnCNN()

    model.compile(optimizer=Adam(), loss=['mse'])
    ckpt = ModelCheckpoint(save_dir + '/model_{epoch:02d}.h5', monitor='val_loss',
                           verbose=0, period=args.save_every)
    csv_logger = CSVLogger(save_dir + '/log.csv', append=True, separator=',')
    lr = LearningRateScheduler(step_decay)
    history = model.fit_generator(train_datagen(datax, datay, batch_size=args.batch_size),
                                  steps_per_epoch=len(datax) // args.batch_size, epochs=args.epoch, verbose=1,
                                  callbacks=[ckpt, csv_logger, lr])  # pass both inputs and labels to the generator
    return model
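The modified train() only works if datax[i] is the noisy version of datay[i], and it now passes a path to load_train_data, so that function has to accept one. The helper below, `load_pair`, is a hypothetical stand-in (not a function in DnCNN-keras) that loads both files, applies the same reshape and scaling as train(), and refuses mismatched shapes. The data is made up and written to a temporary directory only for the demo.

```python
import os
import tempfile
import numpy as np

def load_pair(x_path, y_path):
    """Hypothetical helper: load noisy inputs and clean labels, add a channel axis, scale to [0, 1]."""
    datax = np.load(x_path)
    datay = np.load(y_path)
    if datax.shape != datay.shape:
        raise ValueError(f"shape mismatch: x {datax.shape} vs y {datay.shape}")
    # (N, H, W) -> (N, H, W, 1), matching the reshape in train()
    datax = datax.reshape(datax.shape + (1,)).astype('float32') / 255.0
    datay = datay.reshape(datay.shape + (1,)).astype('float32') / 255.0
    return datax, datay

# Demo with made-up uint8 patches written to a temporary directory.
with tempfile.TemporaryDirectory() as d:
    rng = np.random.default_rng(0)
    clean = rng.integers(0, 256, size=(6, 8, 8), dtype=np.uint8)
    # toy "noisy" patches: clean plus small integer noise, clipped to [0, 255]
    noisy = np.clip(clean.astype(np.int16) + rng.integers(-10, 11, size=clean.shape),
                    0, 255).astype(np.uint8)
    np.save(os.path.join(d, 'x.npy'), noisy)
    np.save(os.path.join(d, 'y.npy'), clean)
    datax, datay = load_pair(os.path.join(d, 'x.npy'), os.path.join(d, 'y.npy'))

print(datax.shape, datax.dtype, float(datay.max()) <= 1.0)  # (6, 8, 8, 1) float32 True
```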

Then modify the test function:

Original function:

def test(model):

    # ... (omitted)
    file_list = glob.glob('{}/*.png'.format(args.test_dir))
    for file in file_list:
        # read image
        img_clean = np.array(Image.open(file), dtype='float32') / 255.0
        img_test = img_clean + np.random.normal(0, args.sigma/255.0, img_clean.shape)
        img_test = img_test.astype('float32')
        # predict
    # ... (omitted)
    pd.DataFrame({'name':np.array(name), 'psnr':np.array(psnr), 'ssim':np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)

Modified:

def test(model):

    # ... (omitted)
    file_list = glob.glob('{}/*.png'.format(args.test_dir))
    for file in file_list:
        # read image (now a real noisy image, despite the name img_clean)
        img_clean = np.array(Image.open(file), dtype='float32') / 255.0
        # img_test = img_clean + np.random.normal(0, args.sigma / 255.0, img_clean.shape)
        # synthetic noise addition commented out
        img_test = img_clean.astype('float32')
        # predict
    # ... (omitted)
    pd.DataFrame({'name':np.array(name), 'psnr':np.array(psnr), 'ssim':np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)
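A consequence of this change: test() now feeds the image read from test_dir straight into the model, so test_dir must contain real noisy images, and the variable `img_clean` actually holds a noisy image. If the omitted part still computes PSNR/SSIM against `img_clean`, those scores compare the output with the noisy input rather than with a clean ground truth, so a clean reference image for each test image is needed. The sketch below uses a hand-written `psnr` helper (an assumption; any standard PSNR implementation would do) on made-up images to show why scoring against the noisy input itself is uninformative.

```python
import numpy as np

def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, data_range]."""
    mse = np.mean((reference.astype(np.float64) - estimate.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(data_range ** 2 / mse)

# Made-up example: a flat clean image and a noisy copy, both in [0, 1].
rng = np.random.default_rng(0)
img_clean = np.full((16, 16), 0.5, dtype=np.float32)
img_noisy = np.clip(img_clean + rng.normal(0, 0.1, img_clean.shape), 0, 1).astype(np.float32)

# Scoring against the clean reference measures how far the image is from ground truth ...
print(round(psnr(img_clean, img_noisy), 1))
# ... while scoring the noisy image against itself says nothing about denoising.
print(psnr(img_noisy, img_noisy))  # inf
```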

To match the changes to train(), add default paths for the two data files:

parser.add_argument('--train_datax', default='./data/npy_data/x.npy', type=str, help='path of the noisy training inputs (x)')
parser.add_argument('--train_datay', default='./data/npy_data/y.npy', type=str, help='path of the clean training labels (y)')
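A quick check of how the two new arguments behave, as a standalone sketch using only the standard `argparse` module; `my_x.npy` is a hypothetical file name.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_datax', default='./data/npy_data/x.npy', type=str,
                    help='path of the noisy training inputs (x)')
parser.add_argument('--train_datay', default='./data/npy_data/y.npy', type=str,
                    help='path of the clean training labels (y)')

# With no flags, the defaults are used.
args = parser.parse_args([])
print(args.train_datax, args.train_datay)  # ./data/npy_data/x.npy ./data/npy_data/y.npy

# Either path can be overridden on the command line, e.g.
#   python main.py --train_datax my_x.npy
args = parser.parse_args(['--train_datax', 'my_x.npy'])
print(args.train_datax, args.train_datay)  # my_x.npy ./data/npy_data/y.npy
```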

data.py

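Simply pack the noisy images and the labels (the corresponding clean images) into two separate .npy files, keeping them in the same order so that each noisy patch lines up with its clean label.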
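One pitfall when packing the two files separately: if data.py augments each patch with a randomly chosen flip or rotation, running it once for the noisy images and once for the clean images draws different random modes, and the pairs no longer match. The sketch below uses hypothetical `data_aug` and `gen_paired_patches` helpers; the patch size, stride, and 8-mode flip/rotation scheme are illustrative assumptions, not taken from DnCNN-keras. It cuts patches from a noisy/clean image pair and applies one shared augmentation mode to both patches at each location. The toy "noisy" image is just `255 - clean`, chosen only so that alignment can be checked pixel by pixel.

```python
import numpy as np

def data_aug(img, mode):
    """Hypothetical augmentation: one of 8 rotation/flip variants (mode 0 = unchanged)."""
    out = np.rot90(img, k=mode % 4)
    return np.flipud(out) if mode >= 4 else out

def gen_paired_patches(noisy, clean, patch_size=40, stride=10, rng=None):
    """Cut aligned patches from a (noisy, clean) image pair.

    Both patches at each location get the SAME augmentation mode,
    so the pair stays aligned after flipping/rotating.
    """
    if noisy.shape != clean.shape:
        raise ValueError("noisy and clean images must have the same shape")
    rng = np.random.default_rng() if rng is None else rng
    h, w = clean.shape
    xs, ys = [], []
    for i in range(0, h - patch_size + 1, stride):
        for j in range(0, w - patch_size + 1, stride):
            mode = int(rng.integers(0, 8))  # one draw shared by both patches
            xs.append(data_aug(noisy[i:i + patch_size, j:j + patch_size], mode))
            ys.append(data_aug(clean[i:i + patch_size, j:j + patch_size], mode))
    return np.array(xs, dtype=np.uint8), np.array(ys, dtype=np.uint8)

# Made-up 64x64 image pair.
rng = np.random.default_rng(0)
clean = rng.integers(0, 256, size=(64, 64), dtype=np.uint8)
noisy = 255 - clean  # toy stand-in for a real noisy image
x, y = gen_paired_patches(noisy, clean, patch_size=40, stride=12, rng=rng)
print(x.shape, bool(np.all(x == 255 - y)))  # (9, 40, 40) True
```

The two returned arrays could then be saved with `np.save` to the paths given by `--train_datax` and `--train_datay`. For a whole dataset, the patches from each image pair would be concatenated in the same order for both arrays.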

I haven't run this model yet. If it fails, I'll come back and keep revising this post.

If it works, treat this as a record of my experience modifying the model (with help from a senior student).