CN112862094A - 一种基于元学习的快速适应drbm方法 - Google Patents

一种基于元学习的快速适应drbm方法 Download PDF

Info

Publication number
CN112862094A
CN112862094A CN202110134999.9A CN202110134999A CN112862094A CN 112862094 A CN112862094 A CN 112862094A CN 202110134999 A CN202110134999 A CN 202110134999A CN 112862094 A CN112862094 A CN 112862094A
Authority
CN
China
Prior art keywords
layer
training
network
learning
column
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110134999.9A
Other languages
English (en)
Inventor
张新禹
刘子衿
任祖煜
霍凯
刘振
张双辉
刘永祥
姜卫东
黎湘
卢哲俊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
National University of Defense Technology
Original Assignee
National University of Defense Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National University of Defense Technology filed Critical National University of Defense Technology
Priority to CN202110134999.9A priority Critical patent/CN112862094A/zh
Publication of CN112862094A publication Critical patent/CN112862094A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/18Complex mathematical operations for evaluating statistical data, e.g. average values, frequency distributions, probability functions, regression analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • Computational Mathematics (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Pure & Applied Mathematics (AREA)
  • Mathematical Optimization (AREA)
  • General Health & Medical Sciences (AREA)
  • Mathematical Analysis (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Operations Research (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Algebra (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明属于机器学习领域,具体涉及一种基于元学习的DRBM方法,通过改进网络的训练‑测试算法,将算法分为元学习和模型学习两个阶段。在元学习阶段利用训练任务更新网络参数,并将更新后的网络参数作为模型学习阶段的网络参数初始值,使网络参数初始值能够使网络训练的损失函数下降更快并且更容易达到全局最优,在模型学习阶段利用测试任务更新网络参数并进行测试。该算法引入元学习的方法对DRBM的训练过程进行改进,使网络参数的元学习阶段梯度下降方向为向“最适应”点下降,使得网络能够快速适应到一个新的任务中。

Description

一种基于元学习的快速适应DRBM方法
技术领域
本发明属于机器学习领域,具体涉及一种基于元学习(Meta Learning)的快速适应判别式受限玻尔兹曼机(Discriminative restricted Boltzmann machine,DRBM)方法。
背景技术
受限玻尔兹曼机(restricted Boltzmann machine,RBM)网络是机器学习中最流行的基础模型之一,也是深度神经网络中最为常用的基本组件。RBM可以利用其隐藏单元进行特征提取和学习数据概率分布并利用学习得到的概率分布生成新的样本,一直受到目标识别和概率模型等领域学者们的广泛研究。DRBM是RBM的一种拓展形式,其核心思想是在一定数量的样本集合中构建一个判别函数,将特征向量和标签一起作为RBM的输入进行训练,使RBM 具有分类的功能。
DRBM最早是由Hugo Larochelle和Yoshua Bengio于2008年提出的(LarochelleH,Bengio Y.Classification using discriminative restricted Boltzmann machines[C]//Proceedings of the 25th international conference on Machinelearning.ACM,2008:536-543.),经过十余年的发展,对其进行过多次优化,已经有较为成熟的网络结构和训练算法。
对于DRBM网络的改进和优化主要可以分为学习算法优化和模型结构优化,元学习的快速适应DRBM网络属于学习算法上的优化。这种元学习方法使得经过DRBM网络训练得到的网络参数θ不再追求在特定训练集上表现最佳,而是追求在所有训练任务中的网络参数初始值θ0都能够只通过几步就快速收敛到最优解。
基于元学习的快速适应DRBM算法的提出受到MAML(Model-Agnostic Meta-Learning, MAML)算法的启发。MAML算法由Chelsea Finn、Pieter Abbeel和Sergey Levine于2017年提出,是一种与模型无关的快速适应元学习算法,适用于任何一种基于梯度下降进行训练的模型,并且适用于各种学习问题,如分类、回归和强化学习。该方法在两种少样本(few-shot) 图像分类数据集(Omniglot和MiniImagenet)上取得了较好的性能,在少样本回归上取得了较好的效果,并利用神经网络策略加速了策略梯度强化学习的微调。
发明内容
本发明的目的是要解决小样本条件下DRBM网络欠拟合以及网络参数初始值无法使网络经过训练后达到全局最优的问题。
本发明的思路是通过改进网络的训练-测试算法,将算法分为元学习和模型学习两个阶段。在元学习阶段利用训练任务更新网络参数,并将更新后的网络参数作为模型学习阶段的网络参数初始值,使网络参数初始值能够使网络训练的损失函数下降更快并且更容易达到全局最优,在模型学习阶段利用测试任务更新网络参数并进行测试。该算法引入元学习的方法对DRBM的训练过程进行改进,使网络参数的元学习阶段梯度下降方向为向“最适应”点下降,使得网络能够快速适应到一个新的任务中。
本发明解决其技术问题所采取的技术方案是:一种基于元学习的快速适应DRBM方法,分为以下步骤:
S1.建立DRBM网络结构。DRBM的网络结构可以分为三层——可见层、隐藏层和分类层,每层包含若干个神经元,神经元的连接方式是同一层内部的节点之间没有任何连接,而层与层之间的节点互相以全连接的方式相互连接在一起。每个神经元的状态均为1或者0的二元取值,1表示激活,0表示未激活,激活意味着该神经元所代表的节点对数据进行了处理。 DRBM的分布由神经元的值确定,其中可见层用于表示输入数据,可见层节点个数由输入数据维度决定,可见层节点取值为输入数据各维的取值;隐藏层按照最优化的方式获取观察数据的某种统计意义上的特征,隐藏层节点个数根据数据和任务不同人为进行调整;分类层单元根据隐藏层单元提取出的数据特征进行类别判定,分类层节点个数由数据类别数量决定。
DRBM网络由网络参数进行描述。假设可见层节点个数为l,隐藏层节点个数为m,分类层节点个数为n,可见层偏置向量为b,b为1行l列向量、隐藏层偏置向量为c,c为1 行m列向量、分类层偏置向量为d,d为1行n列向量,输入层和隐藏层的权重矩阵为W,W 为l行m列矩阵、分类层和隐藏层的权重矩阵为U,U为n行m列矩阵。设向量θ=(W,U,b,c,d),则训练DRBM网络的目的就是寻找最佳的θ值,来通过网络预测数据类别。
DRBM是一种基于能量函数确定的模型,其能量函数可以被定义为:
E(y,x,h)=-hWTxT-bxT-chT-dey T-hUTyT (1)
其中x表示可见层的状态向量,x为1行l列向量、h表示隐藏层单元的状态向量,h为1行m 列向量、y表示分类层的状态向量,y为1行n列向量,y是标签的“独热”(one-hot)型表示,即所有节点中只有一个节点为1,其余节点均为0。x、h、y的联合概率分布为:
Figure RE-GDA0003007253060000021
其中
Figure RE-GDA0003007253060000022
称为配分函数。
S2元学习阶段:
S2.1网络参数初始化:
初始化可见层的偏置向量b为1行l列的零矩阵、隐藏层的偏置向量c为1行m列的零矩阵、分类层的偏置向量d为1行n列的零矩阵,以及对应的梯度Δb为1行l列的零矩阵、Δc为 1行m列的零矩阵、Δd为1行n列的零矩阵,可见层与隐藏层的偏置矩阵W为l行m列的零矩阵和分类层与隐藏层的偏置矩阵U为n行m列的零矩阵,以及对应的梯度ΔW为l行m列的零矩阵、ΔU为n行m列的零矩阵,网络参数θ初始值记作θ0;设置内部学习率α为0.01~0.5、外部学习率β为0.005~0.05、动量(momentum)学习率m=0.5和惩罚(penalty)系数p=10-4
S2.2完成一个任务(task)的训练:
元学习阶段的训练以一个任务为基本单元,每个任务包含两个部分——支撑集(support set)和质询集(query set)。利用支撑集进行训练的过程称为内部学习,利用质询集进行训练的过程称为外部学习。每个任务的数据类别与其他任务可以相同也可以不同。元学习阶段的所有任务共同组成了训练任务(training tasks)。训练任务中的所有样本都既含有数据信息也含有标签信息。具体步骤如下:
S2.2.1将支撑集作为网络输入,θ0作为网络参数初始值,完成一次训练:
S2.2.1.1计算隐藏层的概率分布函数:
p(h|y,x)=sigmoid(x(0)W+y(0)U+c) (3)
其中x(0)为输入数据,y(0)为独热形式的输入数据标签,
Figure RE-GDA0003007253060000031
S2.2.1.2得到隐藏层概率分布函数后,利用Gibbs采样得到隐藏层节点取值。Gibbs采样的具体方法为:产生m个[0,1]上的随机数ri,i∈[1,m]为隐藏层节点序号,若p(hi|y,x)>ri,则节点hi的值为1,否则为0,得到的隐藏层节点分布记作h(0),从概率分布p(h|y,x)中采样出h(0)的Gibbs采样过程记为h(0)~p(h|y,x)。
S2.2.1.3根据隐藏层节点取值重构可见层和分类层,分别计算可见层和分类层的概率分布函数:
Figure RE-GDA0003007253060000032
其中
Figure RE-GDA0003007253060000033
c′表示c的所有可能取值。
S2.2.1.4求得可见层和分类层概率分布函数后,利用Gibbs采样x(1)~p(x|h)和 y(1)~p(y|h)得到可见层和分类层的节点取值x(1)、y(1)
S2.2.1.5根据可见层和分类层的节点取值x(1)、y(1)再次计算隐藏层概率分布函数:
p(h|y,x)=sigmoid(x(1)W+y(1)U+c) (5)
并通过Gibbs采样得到h(1)~p(h|y,x)。
S2.2.1.6根据x(0)、y(0)、x(1)、y(1)、h(0)、h(1),求得网络参数θ的更新梯度:
Figure RE-GDA0003007253060000034
S2.2.1.7输入训练任务集中的任务,根据内部学习率α、动量学习率m和惩罚系数p对梯度进行修正:
Figure RE-GDA0003007253060000041
其中i∈[1,ns],ns为支撑集中的样本数量;
S2.2.1.8根据梯度对网络参数θ进行更新:
Figure RE-GDA0003007253060000042
输入支撑集数据进行训练更新后的网络参数记为θns
S2.2.2将质询集作为网络输入,θs作为网络参数初始值,完成一次训练:
按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度计算,根据公式(9)对修正梯度进行计算:
Figure RE-GDA0003007253060000043
其中β为外部学习率,i∈[1,nq],nq为质询集中的样本数量。最后根据公式(8)完成网络参数更新,得到的网络参数记为θnq
在完成一个任务的训练后,对网络参数进行一次更新,并且更新只保留外部学习部分,即:
θt+1=θt+(θnqns) (10)
其中,t∈[1,nt],t表示第t个训练任务,nt表示训练任务数量,保存网络参数θt+1,作为下一个任务的网络参数初始值,一个任务的训练结束。
S2.3完成所有遍历(epoch):
每一次遍历需要训练若干个任务,任务个数根据数据集大小决定,通常可以设置20~100 组任务。当完成一个任务的训练后,将更新后的网络参数初始值作为下一个任务的网络参数初始值,依次重复S2.2的过程,直至所有任务完成一次训练,这就是一次遍历。完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历。最终得到的网络参数记为θnt
元学习阶段通常需要完成多次遍历,遍历次数取决于网络的收敛速度快慢,通常设定的遍历次数为50次。
S3模型学习阶段:
与元学习阶段相似,模型学习阶段的任务也包含支撑集和质询集,但是模型学习阶段通常只有一个任务,称为测试任务,测试任务中的数据类别通常与训练任务不同。
S3.1将支撑集作为网络输入,θnt作为网络参数初始值,完成一次训练:
按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度的计算,再按照公式(11)对修正梯度进行计算:
Figure RE-GDA0003007253060000051
其中i∈[1,ts],ts为测试任务的支撑集中的样本数量,最后根据公式(8)完成网络参数更新。
S3.2完成所有遍历:
每一次遍历就是对测试任务中的支撑集完成一次训练,完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历,最终得到的网络参数记为θt
模型学习阶段通常需要完成50~150次遍历,遍历次数取决于网络的收敛速度快慢。
S3.3将质询集中的数据输入网络,θt作为网络参数,依次计算出第i个类别的预测概率:
prediction(i)=repeat(d(i),tq)+log(exp(x(0)·W+T(i)·U+c)+1) (12)
其中T为tq行,nc列的矩阵,tq为质询集样本数量,nc为总类别数,T(i)表示偏置向量T除第i列为1外其余列全为0;d(i)表示偏置向量d除第i列为1外其余列全为0; repeat(d(i),tq)表示将偏置向量d(i)重复tq次,变为tq行,n列的矩阵。
计算完nc个类别的预测概率后,取其中最大值所在列,作为类别预测结果,完成目标分类。
通过相关实验验证,本发明取得的有益效果为:
(1)经过元学习阶段的训练后,能够使网络参数初始值更接近“最速收敛点”,只需要输入少量样本进行网络的训练微调,就能够很好地拟合测试样本,并且能够使网络在训练微调后更容易达到全局最优;
(2)能够增强网络的特征表达能力,使用多类任务对网络进行训练时更不易欠拟合;
(3)利用本算法能够使网络参数快速收敛的特点,可以应用到小样本条件下,以提升网络的识别正确率;
(4)不仅可以作为一种训练-测试方法(测试任务与训练任务类别不同),也可以作为一种预训练算法(测试任务与训练任务类别相同)。
附图说明
图1本发明的网络结构图;
图2本发明的算法流程图;
图3本发明对于高分辨距离像(high resolution range profile,HRRP)数据的识别流程;
图4HRRP数据;
图5HRRP数据处理结果;
图6HRRP数据集处理结果;
图7本发明与传统算法的实验结果对比;
图8MNIST数据集的样本;
图9本发明与传统算法的实验结果对比。
具体实施方式
下面结合附图对本发明进行进一步说明:
本发明的实例1展示了所提出算法的对于HRRP数据的完整识别流程以及与传统算法的识别结果对比。实例2展示了所提出算法与传统算法对于MNIST数据集的识别结果对比。
图1为本发明的网络结构,网络分为可见层(v层)、标签层(y层)和隐藏层(h层)。 v=(v1,v2,…,vl)、h=(h1,h2,…,hm)和y∈{0,1}n分别为可见层、隐藏层和标签层的状态向量,标签层向量y∈{0,1}n为独热形式。在本实例中,可见层、隐藏层和标签层节点数分别为201、 130和3个。
图2为本发明的算法流程图,展示了标准的基于元学习的快速适应DRBM算法的流程。
图3为本发明对于HRRP数据的识别流程,对比图2,要多出数据处理阶段。这是由于DRBM网络只能输入二值化的数据,所以获得HRRP数据后首先要进行数据预处理,包括四个步骤:第一步是数据分选,即根据实验不同,挑选不同俯仰角和方位角的数据,再将其顺序进行随机排列;第二步是幅度归一化克服信号的距离敏感性,即按照数据集中幅度最大的样本进行幅度归一化;第三步是均值归一化,以增强网络对不同特征的学习能力。HRRP数据均值归一化的意义在于两个方面:1)减去均值后,剩余部分可以直观地理解为各样本的差值部分;2)在更新网络参数的过程中,算法振荡较小,更容易收敛;第四步是数据二值化,根据幅度转化为概率进行采样得到二值化的HRRP数据。
实例1中使用赛博公司设计的三维飞机模型电磁仿真仿真软件输出的HRRP仿真数据,仿真的三类飞机型号为F-35、F-117和P-51,飞机的具体参数如图4所示。
其中雷达仿真波段为x波段。频率范围9.5GHz-10.5GHz,步长为5MHz。极化方式模式为垂直极化。目标俯仰角为0°~10°,步长为0.1°,方位角为0°~90°,步长为0.1°。因此我们的数据集总共有201个频率点,101个俯仰角和901个方位角,即共有101×901=91001个样本,每个样本为201维。
图5展示了HRRP数据在预处理过程中的变化,其中x轴代表数据维度,y轴代表信号幅度。(a)为预处理前的HRRP数据,(b)为归一化后的HRRP数据,(c)为二值化后的 HRRP数据。
图6展示了一个由8100个HRRP数据组成的数据集,其中x轴代表数据维度,y轴代表样本数,z轴代表信号幅度,第1~2700个样本为F-35飞机HRRP,第2701~5400个样本为 F-117飞机HRRP,第5401~8100个样本为P-51飞机HRRP。(a)为预处理前的HRRP数据, (b)为处理之后的HRRP数据,可以看到经过处理后的HRRP数据满足0-1分布且分布概率与原幅度强度相同。
在实例1中,我们从仿真数据中选取每类飞机方位角为0°~10°和80°~90°,俯仰角为 3°~5°,步长均为0.1°的数据,每类飞机(101+101)*21=4242个样本,共计12726个样本,其中方位角为0.2°的倍数的作为测试集
Figure RE-GDA0003007253060000071
共6363个样本,其余作为训练集
Figure RE-GDA0003007253060000078
共6363个样本。
在数据处理完毕后,将进入算法的训练和测试阶段,在实例1中,我们进行两组实验:
实验1:使用传统算法对DRBM进行训练,从
Figure RE-GDA0003007253060000073
中随机抽取三类目标各n个样本并进行训练,从
Figure RE-GDA0003007253060000074
中随机抽取三类目标各2000个样本并进行测试;
实验2:使用本发明的算法对DRBM进行训练和测试,从
Figure RE-GDA0003007253060000075
中随机抽取三类目标各n 个样本并进行训练,其中每个任务中支撑集占样本数量的1/4,质询集占样本数量的3/4;从
Figure RE-GDA0003007253060000076
中剩余的数据中随机抽取三类目标各n个样本作为测试集中的支撑集并进行训练,从
Figure RE-GDA0003007253060000077
中随机抽取三类目标各2000个样本作为质询集并进行测试;
各类样本个数n的取值范围是[20,400],每组实验都遍历50次(epoch=50)。
实验2具体训练和测试的步骤为:
S1元学习阶段:
初始化参数:可见层偏置向量b=(b1,b2,…,b201)、隐藏层偏置向量c=(c1,c2,…,c130)和标签层的偏置向量d=(d1,d2,d3)分别为相应维度的零向量,可见层与隐藏层之间的权值矩阵W=(wi,j)∈R201×130、标签层与隐藏层之间的权值矩阵U=(ui,j)∈R3×130分别为相应维度的零矩阵,各网络参数对应的梯度均为相应维度的零向量。内部学习率α=0.1、外部学习率β=0.001、动量学习率m=0.5,惩罚系数p=10-4
Figure RE-GDA0003007253060000081
中随机抽取三类目标各n个样本的1/4作为支撑集,3/4作为质询集,按照S1.2和S1.3的步骤依次完成50次遍历的训练,得到训练后的网络参数初始值θ0
S2模型学习阶段:
将从
Figure RE-GDA0003007253060000082
中剩余的数据中随机抽取三类目标各n个样本作为支撑集并进行训练,按照 S2.1和S2.2步骤完成网络参数θ的微调,训练完成后的网络参数记为θs
将从
Figure RE-GDA0003007253060000083
中随机抽取三类目标各2000个样本作为质询集并进行测试。
图7为两组实验的结果,图中圆圈和方框为n取不同值时重复20次实验得到的分类正确率均值,每个节点上的条带为50次实验的分类正确率区间。从图中可以得出如下结论:
1)在实验2中,因为训练集于测试集的数据类别相同,所以本发明被作为一种预训练的方法来使用。可以看出,当n取不同值时,实验2的分类正确率都要高于实验1,这说明元学习得到的网络参数初始值比随机化初始值使网络更容易达到全局最优;
2)在小样本条件下(n≤200),本发明对于分类性能的提升尤为明显,这说明元学习得到的网络参数初始值比随机化初始值收敛更快;
3)当n取不同值时,实验2的分类正确率区间都小于实验1,说明本发明训练所得模型的稳定性更强。
实例2中使用手写体数字(MNIST)数据集进行实验。MNIST数据集由7万张图片及对应标签组成,其中6万张用于训练神经网络,1万张用于测试神经网络。每张图片是一个28*28 像素点的0~9的手写数字图片。图片为黑底白字,黑色用0表示,白色用0~1之间的浮点数表示,越接近1,颜色越白。MNIST数据集提供了每张图片对应的标签,以独热形式给出,即标签向量为一个长度为10的一维数组。MNIST数据集的样本如图8所示。
我们把784个像素点组成一个长度为784的一维数组,这个一维数组就是我们要输入神经网络的输入数据。对于RBM网络来说,网络的输入数据只能是二值的,所以需要对MNIST 数据集进行预处理,我们根据每个像素点的黑白明暗进行随机二值化处理,即根据每个像素点的白色强度值转化为概率进行采样,使其满足二值分布。
我们共进行三组实验:
实验1:使用传统算法对DRBM进行训练,训练0~9的数字各n个,测试8~9的数字各5000个;
实验2:使用即传统算法对DRBM进行训练,训练8~9的数字各n个,测试8~9的数字各5000个;
实验3:使用本发明所提出的算法对DRBM进行训练和测试,训练阶段训练任务集包括 0~7的数字各n个,其中每个任务中支撑集占样本数量的1/4,质询集占样本数量的3/4;测试阶段先用支撑集中的8~9的数字各n个进行参数微调,再用质询集中8~9的数字各5000 个进行测试。
实例2的具体训练和测试的步骤与实例1相似,区别在于网络节点数与学习率的设置有所不同。各类样本个数n的取值范围是[5,1000],每种实验重复了20次,根据其分类正确率的区间进行绘图,结果如图9所示:
图中各节点为n取不同值时重复20次实验得到的分类正确率均值,每个节点上的条带为 20次实验的分类正确率区间。
从图中可以得出如下结论:
处于小样本条件下(n≤100)时:
对比实验1和实验2,当训练数据的类别多于测试数据的类别时,网络对于测试数据的分类正确率会下降。这是由于网络结构限制导致特征表达能力有限,训练任务类别越多,网络越容易陷入欠拟合,使网络对于单一任务的识别能力下降。
对比实验1和实验3,虽然参与网络训练的样本类别和个数相同,但是实验3在n取不同值时,分类正确率都要高于实验1,尤其是在小样本条件下。这是因为所提算法算法在训练阶段更趋向于寻找“最速收敛点”,而不是去逼近最优点。所以即使输入与测试数据类别不同的数字来进行网络训练,依然可以学到良好的初始参数。并且当n取不同值时,实验3的分类正确率区间都小于实验1,说明所提出算法训练所得模型的稳定性更强。
对比实验2和实验3,实验3在n取不同值时,分类正确率都要高于实验2,这说明所提算法并没有因为训练类别过多而降低了网络的特征表达能力,也说明实验3经过训练阶段学习到的网络参数初始值兼顾了不同任务上的学习能力,只需要进行少量的训练,就可以使网络更好地拟合训练数据。
当样本数量充足(n=1000)时,实验3的分类正确率要高于其他两种,这说明通过元学习方法,网络参数达到了全局最优点。
综上所述,本发明能够在小样本条件下有效地提高DRBM分类正确率,有较高工程应用价值。

Claims (5)

1.一种基于元学习的快速适应DRBM方法,其特征在于,该方法分为以下步骤:
S1.建立DRBM网络结构:DRBM的网络结构可以分为三层——可见层、隐藏层和分类层,每层包含若干个神经元,神经元的连接方式是同一层内部的节点之间没有任何连接,而层与层之间的节点互相以全连接的方式相互连接在一起;每个神经元的状态均为1或者0的二元取值,1表示激活,0表示未激活,激活意味着该神经元所代表的节点对数据进行了处理;DRBM的分布由神经元的值确定,其中可见层用于表示输入数据,可见层节点个数由输入数据维度决定,可见层节点取值为输入数据各维的取值;隐藏层按照最优化的方式获取观察数据的某种统计意义上的特征,隐藏层节点个数根据数据和任务不同人为进行调整;分类层单元根据隐藏层单元提取出的数据特征进行类别判定,分类层节点个数由数据类别数量决定;
DRBM网络由网络参数进行描述;假设可见层节点个数为l,隐藏层节点个数为m,分类层节点个数为n,可见层偏置向量为b,b为1行l列向量、隐藏层偏置向量为c,c为1行m列向量、分类层偏置向量为d,d为1行n列向量,输入层和隐藏层的权重矩阵为W,W为l行m列矩阵、分类层和隐藏层的权重矩阵为U,U为n行m列矩阵;设向量θ=(W,U,b,c,d),则训练DRBM网络的目的就是寻找最佳的θ值,来通过网络预测数据类别;
DRBM是一种基于能量函数确定的模型,其能量函数可以被定义为:
E(y,x,h)=-hWTxT-bxT-chT-dey T-hUTyT (1)
其中x表示可见层的状态向量,x为1行l列向量、h表示隐藏层单元的状态向量,h为1行m列向量、y表示分类层的状态向量,y为1行n列向量,y是标签的“独热”型表示,即所有节点中只有一个节点为1,其余节点均为0;x、h、y的联合概率分布为:
Figure FDA0002923473620000011
其中
Figure FDA0002923473620000012
称为配分函数;
S2元学习阶段:
S2.1网络参数初始化:
初始化可见层的偏置向量b为1行l列的零矩阵、隐藏层的偏置向量c为1行m列的零矩阵、分类层的偏置向量d为1行n列的零矩阵,以及对应的梯度Δb为1行l列的零矩阵、Δc为1行m列的零矩阵、Δd为1行n列的零矩阵,可见层与隐藏层的偏置矩阵W为l行m列的零矩阵和分类层与隐藏层的偏置矩阵U为n行m列的零矩阵,以及对应的梯度ΔW为l行m列的零矩阵、ΔU为n行m列的零矩阵,网络参数θ初始值记作θ0;设置内部学习率α为0.01~0.5、外部学习率β为0.005~0.05、动量学习率m=0.5和惩罚系数p=10-4
S2.2完成一个任务的训练:
元学习阶段的训练以一个任务为基本单元,每个任务包含两个部分——支撑集和质询集;利用支撑集进行训练的过程称为内部学习,利用质询集进行训练的过程称为外部学习;每个任务的数据类别与其他任务可以相同也可以不同;元学习阶段的所有任务共同组成了训练任务;训练任务中的所有样本都既含有数据信息也含有标签信息;具体步骤如下:
S2.2.1将支撑集作为网络输入,θ0作为网络参数初始值,完成一次训练:
S2.2.1.1计算隐藏层的概率分布函数:
p(h|y,x)=sigmoid(x(0)W+y(0)U+c) (3)
其中x(0)为输入数据,y(0)为独热形式的输入数据标签,
Figure FDA0002923473620000021
S2.2.1.2得到隐藏层概率分布函数后,利用Gibbs采样得到隐藏层节点取值;
S2.2.1.3根据隐藏层节点取值重构可见层和分类层,分别计算可见层和分类层的概率分布函数:
Figure FDA0002923473620000022
其中
Figure FDA0002923473620000023
c′表示c的所有可能取值;
S2.2.1.4求得可见层和分类层概率分布函数后,利用Gibbs采样x(1)~p(x|h)和y(1)~p(y|h)得到可见层和分类层的节点取值x(1)、y(1)
S2.2.1.5根据可见层和分类层的节点取值x(1)、y(1)再次计算隐藏层概率分布函数:
p(h|y,x)=sigmoid(x(1)W+y(1)U+c) (5)
并通过Gibbs采样得到h(1)~p(h|y,x);
S2.2.1.6根据x(0)、y(0)、x(1)、y(1)、h(0)、h(1),求得网络参数θ的更新梯度:
Figure FDA0002923473620000024
S2.2.1.7输入训练任务集中的任务,根据内部学习率α、动量学习率m和惩罚系数p对梯度进行修正:
Figure FDA0002923473620000031
其中i∈[1,ns],ns为支撑集中的样本数量;
S2.2.1.8根据梯度对网络参数θ进行更新:
Figure FDA0002923473620000032
输入支撑集数据进行训练更新后的网络参数记为θns
S2.2.2将质询集作为网络输入,θs作为网络参数初始值,完成一次训练:
按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度计算,根据公式(9)对修正梯度进行计算:
Figure FDA0002923473620000033
其中β为外部学习率,i∈[1,nq],nq为质询集中的样本数量;最后根据公式(8)完成网络参数更新,得到的网络参数记为θnq
在完成一个任务的训练后,对网络参数进行一次更新,并且更新只保留外部学习部分,即:
θt+1=θt+(θnqns) (10)
其中,t∈[1,nt],t表示第t个训练任务,nt表示训练任务数量,保存网络参数θt+1,作为下一个任务的网络参数初始值,一个任务的训练结束;
S2.3完成所有遍历:
每一次遍历需要训练若干个任务,任务个数根据数据集大小决定;当完成一个任务的训练后,将更新后的网络参数初始值作为下一个任务的网络参数初始值,依次重复S2.2的过程,直至所有任务完成一次训练,这就是一次遍历;完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历,最终得到的网络参数记为θnt
S3模型学习阶段:
S3.1将支撑集作为网络输入,θnt作为网络参数初始值,完成一次训练:
按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度的计算,再按照公式(11)对修正梯度进行计算:
Figure FDA0002923473620000041
其中i∈[1,ts],ts为测试任务的支撑集中的样本数量,最后根据公式(8)完成网络参数更新;
S3.2完成所有遍历:
每一次遍历就是对测试任务中的支撑集完成一次训练,完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历,最终得到的网络参数记为θt
S3.3将质询集中的数据输入网络,θt作为网络参数,依次计算出第i个类别的预测概率:
prediction(i)=repeat(d(i),tq)+log(exp(x(0)·W+T(i)·U+c)+1) (12)
其中T为tq行,nc列的矩阵,tq为质询集样本数量,nc为总类别数,T(i)表示偏置向量T除第i列为1外其余列全为0;d(i)表示偏置向量d除第i列为1外其余列全为0;repeat(d(i),tq)表示将偏置向量d(i)重复tq次,变为tq行,n列的矩阵;
计算完nc个类别的预测概率后,取其中最大值所在列,作为类别预测结果,完成目标分类。
2.一种根据权利要求1所述基于元学习的快速适应DRBM方法,其特征在于:S2.2.1.2中,Gibbs采样的具体方法为:产生m个[0,1]上的随机数ri,i∈[1,m]为隐藏层节点序号,若p(hi|y,x)>ri,则节点hi的值为1,否则为0,得到的隐藏层节点分布记作h(0),从概率分布p(h|y,x)中采样出h(0)的Gibbs采样过程记为h(0)~p(h|y,x)。
3.一种根据权利要求1所述基于元学习的快速适应DRBM方法,其特征在于:S2.3中,设置20~100组任务。
4.一种根据权利要求1所述基于元学习的快速适应DRBM方法,其特征在于:元学习阶段设定的遍历次数为50次。
5.一种根据权利要求1所述基于元学习的快速适应DRBM方法,其特征在于:模型学习阶段设定的遍历次数为50~150次。
CN202110134999.9A 2021-01-29 2021-01-29 一种基于元学习的快速适应drbm方法 Pending CN112862094A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110134999.9A CN112862094A (zh) 2021-01-29 2021-01-29 一种基于元学习的快速适应drbm方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110134999.9A CN112862094A (zh) 2021-01-29 2021-01-29 一种基于元学习的快速适应drbm方法

Publications (1)

Publication Number Publication Date
CN112862094A true CN112862094A (zh) 2021-05-28

Family

ID=75987291

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110134999.9A Pending CN112862094A (zh) 2021-01-29 2021-01-29 一种基于元学习的快速适应drbm方法

Country Status (1)

Country Link
CN (1) CN112862094A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115114844A (zh) * 2022-05-09 2022-09-27 东南大学 一种钢筋混凝土粘结滑移曲线的元学习预测模型
CN116737939A (zh) * 2023-08-09 2023-09-12 恒生电子股份有限公司 元学习方法、文本分类方法、装置、电子设备及存储介质

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115114844A (zh) * 2022-05-09 2022-09-27 东南大学 一种钢筋混凝土粘结滑移曲线的元学习预测模型
CN115114844B (zh) * 2022-05-09 2023-09-19 东南大学 一种钢筋混凝土粘结滑移曲线的元学习预测模型
CN116737939A (zh) * 2023-08-09 2023-09-12 恒生电子股份有限公司 元学习方法、文本分类方法、装置、电子设备及存储介质
CN116737939B (zh) * 2023-08-09 2023-11-03 恒生电子股份有限公司 元学习方法、文本分类方法、装置、电子设备及存储介质

Similar Documents

Publication Publication Date Title
CN110020682B (zh) 一种基于小样本学习的注意力机制关系对比网络模型方法
CN110516596B (zh) 基于Octave卷积的空谱注意力高光谱图像分类方法
CN109063724B (zh) 一种增强型生成式对抗网络以及目标样本识别方法
CN111898730A (zh) 一种利用图卷积神经网络结构加速的结构优化设计方法
CN113222011B (zh) 一种基于原型校正的小样本遥感图像分类方法
CN109146000B (zh) 一种基于冰冻权值改进卷积神经网络的方法及装置
CN111914728B (zh) 高光谱遥感影像半监督分类方法、装置及存储介质
CN109740679B (zh) 一种基于卷积神经网络和朴素贝叶斯的目标识别方法
CN112884059B (zh) 一种融合先验知识的小样本雷达工作模式分类方法
CN105528638A (zh) 灰色关联分析法确定卷积神经网络隐层特征图个数的方法
CN112862094A (zh) 一种基于元学习的快速适应drbm方法
CN111832580B (zh) 结合少样本学习与目标属性特征的sar目标识别方法
CN110110845B (zh) 一种基于并行多级宽度神经网络的学习方法
CN111311702B (zh) 一种基于BlockGAN的图像生成和识别模块及方法
Minh et al. Automated image data preprocessing with deep reinforcement learning
CN111767860A (zh) 一种通过卷积神经网络实现图像识别的方法及终端
CN113987236B (zh) 基于图卷积网络的视觉检索模型的无监督训练方法和装置
Puri et al. Few shot learning for point cloud data using model agnostic meta learning
CN111310791A (zh) 一种基于小样本数目集的动态渐进式自动目标识别方法
Bianchi et al. Improving image classification robustness through selective cnn-filters fine-tuning
JP2000259766A (ja) パターン認識方法
CN109948589A (zh) 基于量子深度信念网络的人脸表情识别方法
CN111783688A (zh) 一种基于卷积神经网络的遥感图像场景分类方法
Sufikarimi et al. Speed up biological inspired object recognition, HMAX
CN114037866A (zh) 一种基于可辨伪特征合成的广义零样本图像分类方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20210528