MATLAB环境下生成对抗网络系列(11种)
MATLAB环境下生成对抗网络系列(11种)
为了构建有效的图像深度学习模型,数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括:几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增强、对抗训练、生成对抗网络和风格迁移等内容。增强的数据代表一个分布覆盖性更广、可靠性更高的数据点集,使用增强数据能够有效增加训练样本的多样性,最小化训练集和验证集以及测试集之间的距离。使用数据增强后的数据集训练模型,可以达到提升模型稳定性、泛化能力的效果。
使用生成对抗网络GAN提取原数据集特征,对抗生成新的目标域图像,已成为众多学者在数据增强技术研究中的优选方法。相比于传统的图像数据增强方法,通过基于GAN的生成式建模技术进行数据增强的思路来源于博弈论中的二人零和博弈,由网络中包含的生成器和判别器利用对抗学习的方法来指导网络训练,在两个网络对抗过程中估计原始数据样本的分布并生成与之相似的新数据。
近期的研究在原始生成对抗网络框架的基础上又提出了多种不同的改进方案,通过设计不同的神经网络架构和损失函数等手段不断提升生成对抗网络的性能。生成对抗网络应用在图像数据增强任务上的思想主要是其通过生成新的训练数据来扩充模型的训练数据,通过样本空间的扩充实现图像分类等任务效果的提升。但目前基于GAN的图像数据增强技术普遍存在模型收敛不稳定、生成图像质量低等问题,如何正确引入高频信息,提升图像数据质量是破解这一系列问题的关键。
MATLAB环境配置如下:
- MATLAB 2021b
- Deep Learning Toolbox
- Parallel Computing Toolbox (optional for GPU usage)
目录如下
- Generative Adversarial Network (GAN) [paper]
- Least Squares Generative Adversarial Network (LSGAN) [paper]
- Deep Convolutional Generative Adversarial Network (DCGAN) [paper]
- Conditional Generative Adversarial Network (CGAN)[paper]
- Auxiliary Classifier Generative Adversarial Network (ACGAN) [paper]
- InfoGAN [paper]
- Adversarial AutoEncoder (AAE)[paper]
- Pix2Pix[paper]
- Wasserstein Generative Adversarial Network (WGAN) [paper]
- Semi-Supervised Generative Adversarial Network (SGAN) [paper]
- CycleGAN [paper]
- DiscoGAN [paper]
部分代码如下:
首先,导入相关的mnist手写数字图
load('mnistAll.mat')
然后对训练、测试图像进行预处理
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;%测试标签
preprocess为归一化函数,如下
function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end
然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等
settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1];
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;
下面进行编码器初始化,代码还是很容易看懂的
paramsEn.FCW1 = dlarray(initializeGaussian([512,...
prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));
解码器初始化
paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));
判别器初始化
paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));
%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];
开始训练
dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
tic;
shuffleid = randperm(size(trainX,2));
trainXshuffle = trainX(:,shuffleid);
fprintf('Epoch %d\n',epoch)
for i=1:numIterations
global_iter = global_iter+1;
idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');
[GradEn,GradDe,GradDis] = ...
dlfeval(@modelGradients,XBatch,...
paramsEn,paramsDe,paramsDis,settings);
% 更新判别器网络参数
[paramsDis,avgG.Dis,avgGS.Dis] = ...
adamupdate(paramsDis, GradDis, ...
avgG.Dis, avgGS.Dis, global_iter, ...
settings.lrD, settings.beta1, settings.beta2);
% 更新编码器网络参数
[paramsEn,avgG.En,avgGS.En] = ...
adamupdate(paramsEn, GradEn, ...
avgG.En, avgGS.En, global_iter, ...
settings.lrG, settings.beta1, settings.beta2);
% 更新解码器网络参数
[paramsDe,avgG.De,avgGS.De] = ...
adamupdate(paramsDe, GradDe, ...
avgG.De, avgGS.De, global_iter, ...
settings.lrG, settings.beta1, settings.beta2);
if i==1 || rem(i,20)==0
progressplot(paramsDe,settings);
if i==1
h = gcf;
% 捕获图像
frame = getframe(h);
im = frame2im(frame);
[imind,cm] = rgb2ind(im,256);
% 写入 GIF 文件
if epoch == 0
imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf);
else
imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append');
end
end
end
end
elapsedTime = toc;
disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
epoch = epoch+1;
if epoch == settings.maxepochs
out = true;
end
end
下面是完整的辅助函数
模型的梯度计算函数
function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...
dly(settings.latent_dim+1:2*settings.latent_dim)*...
randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');
%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));
%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));
%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end
提取数据函数
function x = gatext(x)
x = gather(extractdata(x));
end
GPU深度学习数组wrapper函数
function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end
权重初始化函数
function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end
dropout函数
function dly = dropout(dlx,p)
if nargin < 2
p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end
编码器函数
function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end
解码器函数
function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end
判别器函数
function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end
工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。
擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。
更多推荐
所有评论(0)