CN112365385A - 基于自注意力的知识蒸馏方法、装置和计算机设备 - Google Patents

基于自注意力的知识蒸馏方法、装置和计算机设备 Download PDF

Info

Publication number
CN112365385A
CN112365385A CN202110059942.7A CN202110059942A CN112365385A CN 112365385 A CN112365385 A CN 112365385A CN 202110059942 A CN202110059942 A CN 202110059942A CN 112365385 A CN112365385 A CN 112365385A
Authority
CN
China
Prior art keywords
model
self
matrix
weight distribution
feature matrix
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110059942.7A
Other languages
English (en)
Other versions
CN112365385B (zh
Inventor
徐泓洋
王广新
杨汉丹
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Youjie Zhixin Technology Co ltd
Original Assignee
Shenzhen Youjie Zhixin Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Youjie Zhixin Technology Co ltd filed Critical Shenzhen Youjie Zhixin Technology Co ltd
Priority to CN202110059942.7A priority Critical patent/CN112365385B/zh
Publication of CN112365385A publication Critical patent/CN112365385A/zh
Application granted granted Critical
Publication of CN112365385B publication Critical patent/CN112365385B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q50/00Information and communication technology [ICT] specially adapted for implementation of business processes of specific business sectors, e.g. utilities or tourism
    • G06Q50/10Services
    • G06Q50/20Education
    • G06Q50/205Education administration or guidance
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • G06F17/16Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q10/00Administration; Management
    • G06Q10/06Resources, workflows, human or project management; Enterprise or organisation planning; Enterprise or organisation modelling
    • G06Q10/067Enterprise or organisation modelling

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Business, Economics & Management (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Strategic Management (AREA)
  • Mathematical Physics (AREA)
  • Data Mining & Analysis (AREA)
  • Software Systems (AREA)
  • Human Resources & Organizations (AREA)
  • Economics (AREA)
  • Educational Administration (AREA)
  • Tourism & Hospitality (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Educational Technology (AREA)
  • General Business, Economics & Management (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Entrepreneurship & Innovation (AREA)
  • Pure & Applied Mathematics (AREA)
  • Mathematical Optimization (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Analysis (AREA)
  • Health & Medical Sciences (AREA)
  • Marketing (AREA)
  • Computational Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Algebra (AREA)
  • Medical Informatics (AREA)
  • Primary Health Care (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Databases & Information Systems (AREA)
  • Development Economics (AREA)
  • Computational Linguistics (AREA)
  • Game Theory and Decision Science (AREA)
  • Operations Research (AREA)

Abstract

本申请涉及人工智能领域,揭示了基于自注意力的知识蒸馏方法,包括:将输入数据输入第一模型得到第一模型的中间层输出的第一特征矩阵,将输入数据输入第二模型得到第二模型的中间层输出的第二特征矩阵,其中,第一模型为训练好的老师模型,第二模型为待训练的学生模型;根据第一特征矩阵计算老师模型对应的第一自注意力权重分布,根据第二特征矩阵计算学生模型对应的第二自注意力权重分布;计算第一自注意力权重分布和第二自注意力权重分布之间的分布差异;将分布差异,作为老师模型和学生模型之间的知识蒸馏损失函数;根据知识蒸馏损失函数,将老师模型的中间层的数据映射关系迁移至学生模型的中间层上,能满足不同任务类型模型的知识蒸馏训练。

Description

基于自注意力的知识蒸馏方法、装置和计算机设备
技术领域
本申请涉及人工智能领域,特别是涉及到基于自注意力的知识蒸馏方法、装置和计算机设备。
背景技术
知识蒸馏(Knowledge Distillation)是一种特殊的迁移学习方法,目的主要是在保证训练效果的同时对待训练模型的体积进行压缩。通过训练好的老师模型来指导小体积的待训练学生模型的学习,通过训练小模型学习到大模型的知识,相比于直接训练小模型效果更好,速度更快。
目前用于知识蒸馏的损失函数更多的是针对分类模型进行的,其要求大模型和小模型的类别数或者网络输出特征维度要一致,限制了知识蒸馏的应用范围,不能满足不同任务类型模型的知识蒸馏训练。
发明内容
本申请的主要目的为提供基于自注意力的知识蒸馏方法,旨在解决现有知识蒸馏的损失函数设计,不能满足不同任务类型模型的知识蒸馏训练的技术问题。
本申请提出一种基于自注意力的知识蒸馏方法,包括:
将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
优选地,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤,包括:
将所述第一特征矩阵进行转置计算,得到所述第一特征矩阵对应的第一转置矩阵;
根据所述第一特征矩阵和所述第一转置矩阵,计算所述第一特征矩阵的内部特征关系;
对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
优选地,所述老师模型使用多头注意力机制,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤,包括:
将所述第一特征矩阵按照所述多头注意力机制对应的头数进行均分,得到多个子矩阵;
对第一子矩阵进行转置计算,得到所述第一子矩阵对应的第一转置子矩阵,其中,所述第一子矩阵为所述第一特征矩阵对应的多个子矩阵中的任一个;
根据所述第一子矩阵和所述第一转置子矩阵,计算所述第一子矩阵的内部特征关系;
根据所述第一子矩阵的内部特征关系的计算方式,计算所述第一特征矩阵的各子矩阵分别对应的内部特征关系;
将所述第一特征矩阵的各子矩阵分别对应的内部特征关系,按照各子矩阵在所述第一特征矩阵中的排布次序,拼接成所述第一特征矩阵的内部特征关系;
对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
优选地,所述计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异的步骤,包括:
计算所述第一自注意力权重分布和所述第一自注意力权重分布之间的KL散度损失;
将所述KL散度损失作为所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异。
优选地,所述根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上的步骤之后,包括:
获取预设的所述学生模型的任务类型;
根据所述学生模型的任务类型匹配全连接层和目标函数,其中,所述全连接层连接于所述学生模型的中间层的输出端;
根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数;
根据所述总损失函数在训练集上训练所述学生模型。
优选地,所述根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数的步骤,包括:
获取所述知识蒸馏损失函数和所述目标函数的数量级别差;
根据所述数量级别差确定调节参数;
根据所述调节参数、所述知识蒸馏损失函数和所述目标函数,训练所述学生模型的总损失函数。
本申请还提供了一种基于自注意力的知识蒸馏装置,包括:
输入模块,用于将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
第一计算模块,用于根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
第二计算模块,用于计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
作为模块,用于将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
迁移模块,用于根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
本申请还提供了一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现上述方法的步骤。
本申请还提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述方法的步骤。
本申请通过自注意力机制的注意力权重表征内部结构关系,注意力权重由内部元素两两之间计算获得,无视特征之间的距离,可很好的表达内部结构关系,不机械要求大模型和小模型的类别数或者网络输出特征维度要一致,能满足不同任务类型模型的知识蒸馏训练。
附图说明
图1 本申请一实施例的基于自注意力的知识蒸馏方法流程示意图;
图2 本申请一实施例的学生模型的训练过程架构示意图;
图3 本申请一实施例的基于自注意力的知识蒸馏装置结构示意图;
图4 本申请一实施例的计算机设备内部结构示意图。
具体实施方式
为了使本申请的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请,并不用于限定本申请。
参照图1,本申请一实施例的基于自注意力的知识蒸馏方法,包括:
S1:将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
S2:根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
S3:计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
S4:将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
S5:根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
本申请实施例中,对老师模型和学生模型的具体结构不做限制,同一个输入数据输入到老师模型和学生模型,中间层输出的中间态数据为特征矩阵,比如两个模型的中间层分别输出的矩阵特征为feat_t和feat_s,只要feat_t和feat_s有相同的序列长度,即可实现知识蒸馏。比如,在语音识别任务中,当输入一个2s的音频时,分帧标准为20ms为一帧,步长为10ms,则共有199帧音频数据,对应的声学特征矩阵的形状为199*161,表示199帧,每帧数据的特征维数为161。将上述声学特征矩阵分别输入老师模型和学生模型时,输出的特征矩阵的形状应满足199 * N,帧数199保持不变,特征维数N根据选定的网络得出。比如全连接网络中特征维数N与全连接的节点数有关,卷积网络中特征维数N与卷积核的大小有关。
本申请通过老师模型的中间层和学生模型的中间层输出的特征矩阵作为知识蒸馏分析样本,通过计算中间数据态的特征矩阵的自注意力权重分布差异,构建知识蒸馏函数来评估老师模型和学生模型提取到的实例知识的差异性,或者评估老师模型和学生模型提取到的实例间关系分布知识的差异性。
本申请通过自注意力机制的注意力权重表征内部结构关系,注意力权重由内部元素两两之间计算获得,无视特征之间的距离,可很好的表达内部结构关系,不机械要求大模型和小模型的类别数或者网络输出特征维度要一致,能满足不同任务类型模型的知识蒸馏训练。
进一步地,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤S2,包括:
S21:将所述第一特征矩阵进行转置计算,得到所述第一特征矩阵对应的第一转置矩阵;
S22:根据所述第一特征矩阵和所述第一转置矩阵,计算所述第一特征矩阵的内部特征关系;
S23:对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
本申请实施例在计算自注意力权重时,输入特征维度到输出特征的维度发生了变化,通过矩阵转置计算消除维度变化的影响。比如输入的特征矩阵维度表示为n*m ,n为序列长度,m为特征维度,通过矩阵转置计算特征矩阵的内部关系分布,即(n*m)*(m*n)=n*n,消除特征维度的影响,得到的注意力权重矩阵 n*n是一个方阵,当特征矩阵的帧数保持一致时,就可以通过KL散度公式计算两个方阵的分布差异。假如设定输入数据为x,老师模型表示为T,学生模型表示为S;T的中间层输出的特征矩阵表示为F_t,F_t=n*m,S的中间层输出的特征矩阵表示为F_s,F_s=n*p。通过矩阵转置计算特征矩阵内部关系,并通过softmax函数计算得到老师模型的自注意力权重分布,即d_t=softmax(score(F_t,F_t)),由特征矩阵n*m转换成特征方阵n*n,score()表示缩放点乘函数,
Figure 362861DEST_PATH_IMAGE001
Figure 778799DEST_PATH_IMAGE002
表示特征矩阵F_t的转置,
Figure 780253DEST_PATH_IMAGE003
表示特征矩阵F_t的特征维度。学生模型的自注意力权重分布的计算过程与老师模型的自注意力权重分布的计算过程相同。即学生模型的自注意力权重分布为d_s=softmax(score(F_s,F_s)),由特征矩阵n*p转换成特征方阵n*n。进一步地,所述老师模型使用多头注意力机制,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤S2,包括:
S201:将所述第一特征矩阵按照所述多头注意力机制对应的头数进行均分,得到多个子矩阵;
S202:对第一子矩阵进行转置计算,得到所述第一子矩阵对应的第一转置子矩阵,其中,所述第一子矩阵为所述第一特征矩阵对应的多个子矩阵中的任一个;
S203:根据所述第一子矩阵和所述第一转置子矩阵,计算所述第一子矩阵的内部特征关系;
S204:根据所述第一子矩阵的内部特征关系的计算方式,计算所述第一特征矩阵的各子矩阵分别对应的内部特征关系;
S205:将所述第一特征矩阵的各子矩阵分别对应的内部特征关系,按照各子矩阵在所述第一特征矩阵中的排布次序,拼接成所述第一特征矩阵的内部特征关系;
S206:对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
本申请实施例的注意力机制采用多头注意力,以增强捕捉特征矩阵的局部结构信息。本申请通过将特征矩阵均匀等分成多个块,匹配多头注意力。举例地,T的中间层输出的特征矩阵F_t为n*h*i,其中m =h*i;S的中间层输出的特征矩阵F_s为n*h*j,其中p=h*j,h表示将特征矩阵均匀分成的块数,多头注意力的头数为h。老师模型的多头自注意力权重分布d_t= softmax(score(F_t,F_t)),特征矩阵由n*h*i转换成h*n*n。学生模型同样使用多头注意力机制时,学生模型的多头自注意力权重分布d_s=softmax(score(F_s,F_s)),特征矩阵由 n*h*j 转换成 h*n*n。在应用多头注意力权重分布的时候需要满足,m=h*i和p=h*j即单列的向量维度可以被头数h整除,便于按照头数对特征矩阵进行均匀等分。本申请其他实施例中老师模型和学生模型可选择一个使用多头注意力机制,另一个使用单头注意力机制,不作限定,只要确保两者输出的序列长度相同,即可使用本申请的知识蒸馏函数进行知识蒸馏。
进一步地,所述计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异的步骤S3,包括:
S31:计算所述第一自注意力权重分布和所述第一自注意力权重分布之间的KL散度损失;
S32:将所述KL散度损失作为所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异。
本申请实施例中,为缩小计算数值并减小计算量,将相似度用softmax函数或者其它函数转化为[0,1]区间的概率值,然后用KLDiv(Kullback-Leibler Divergence, KL散度)去计算KL散度损失,KL散度损失为衡量两个自注意力权重分布之间的分布差异,表示为KLdiv(d_t,d_s),以评估两个自注意力权重分布的分布差异。
进一步地,所述根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上的步骤S5之后,包括:
S6:获取预设的所述学生模型的任务类型;
S7:根据所述学生模型的任务类型匹配全连接层和目标函数,其中,所述全连接层连接于所述学生模型的中间层的输出端;
S8:根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数;
S9:根据所述总损失函数在训练集上训练所述学生模型。
本申请实施例中,fc(fully connect,全连接层)为模型的最后一层与最终的任务类型相关。当任务是分类任务时,fc用于分类;当任务是回归任务时,fc用于拟合回归。知识蒸馏时用前述方法构建基于注意力权重的损失函数Loss1,fc层的输出使用与fc层任务类型对应的目标函数作为损失函数TargertLoss,最终训练学生模型时的总损失函数是上述两个损失函数的和,即TotalLoss=Loss1*lambda+TargertLoss,lambda为调节参数,为大于零的实数。本申请实施例的学生模型的训练过程架构示意图如图2所示。
进一步地,所述根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数的步骤S8,包括:
S81:获取所述知识蒸馏损失函数和所述目标函数的数量级别差;
S82:根据所述数量级别差确定调节参数;
S83:根据所述调节参数、所述知识蒸馏损失函数和所述目标函数,训练所述学生模型的总损失函数。
本申请实施例根据知识蒸馏损失函数和所述目标函数的数量级别差,选择调节参数,以调节两部分函数的函数值大小对总函数的影响,以平衡两个损失函数值的大小,共同约束学生模型的训练。
参照图3,本申请一实施例的基于自注意力的知识蒸馏装置,包括:
输入模块1,用于将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
第一计算模块2,用于根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
第二计算模块3,用于计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
作为模块4,用于将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
迁移模块5,用于根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
本申请实施例中,对老师模型和学生模型的具体结构不做限制,同一个输入数据输入到老师模型和学生模型,中间层输出的中间态数据为特征矩阵,比如两个模型的中间层分别输出的矩阵特征为feat_t和feat_s,只要feat_t和feat_s有相同的序列长度,即可实现知识蒸馏。比如,在语音识别任务中,当输入一个2s的音频时,分帧标准为20ms为一帧,步长为10ms,则共有199帧音频数据,对应的声学特征矩阵的形状为199*161,表示199帧,每帧数据的特征维数为161。将上述声学特征矩阵分别输入老师模型和学生模型时,输出的特征矩阵的形状应满足199 * N,帧数199保持不变,特征维数N根据选定的网络得出。比如全连接网络中特征维数N与全连接的节点数有关,卷积网络中特征维数N与卷积核的大小有关。
本申请通过老师模型的中间层和学生模型的中间层输出的特征矩阵作为知识蒸馏分析样本,通过计算中间数据态的特征矩阵的自注意力权重分布差异,构建知识蒸馏函数来评估老师模型和学生模型提取到的实例知识的差异性,或者评估老师模型和学生模型提取到的实例间关系分布知识的差异性。
本申请通过自注意力机制的注意力权重表征内部结构关系,注意力权重由内部元素两两之间计算获得,无视特征之间的距离,可很好的表达内部结构关系,不机械要求大模型和小模型的类别数或者网络输出特征维度要一致,能满足不同任务类型模型的知识蒸馏训练。
进一步地,第一计算模块2,包括:
第一计算单元,用于将所述第一特征矩阵进行转置计算,得到所述第一特征矩阵对应的第一转置矩阵;
第二计算单元,用于根据所述第一特征矩阵和所述第一转置矩阵,计算所述第一特征矩阵的内部特征关系;
第三计算单元,用于对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
本申请实施例在计算自注意力权重时,输入特征维度到输出特征的维度发生了变化,通过矩阵转置计算消除维度变化的影响。比如输入的特征矩阵维度表示为n*m ,n为序列长度,m为特征维度,通过矩阵转置计算特征矩阵的内部关系分布,即(n*m)*(m*n)=n*n,消除特征维度的影响,得到的注意力权重矩阵 n*n是一个方阵,当特征矩阵的帧数保持一致时,就可以通过KL散度公式计算两个方阵的分布差异。假如设定输入数据为x,老师模型表示为T,学生模型表示为S;T的中间层输出的特征矩阵表示为F_t,F_t=n*m,S的中间层输出的特征矩阵表示为F_s,F_s=n*p。通过矩阵转置计算特征矩阵内部关系,并通过softmax函数计算得到老师模型的自注意力权重分布,即d_t=softmax(score(F_t,F_t)),由特征矩阵n*m转换成特征方阵n*n,score()表示缩放点乘函数,
Figure 790539DEST_PATH_IMAGE004
Figure 52893DEST_PATH_IMAGE005
表示特征矩阵F_t的转置,
Figure 780678DEST_PATH_IMAGE006
表示特征矩阵F_t的特征维度。学生模型的自注意力权重分布的计算过程与老师模型的自注意力权重分布的计算过程相同。即学生模型的自注意力权重分布为d_s=softmax(score(F_s,F_s)),由特征矩阵n*p转换成特征方阵n*n。
进一步地,所述老师模型使用多头注意力机制,第一计算模块2,包括:
均分单元,用于将所述第一特征矩阵按照所述多头注意力机制对应的头数进行均分,得到多个子矩阵;
第四计算单元,用于对第一子矩阵进行转置计算,得到所述第一子矩阵对应的第一转置子矩阵,其中,所述第一子矩阵为所述第一特征矩阵对应的多个子矩阵中的任一个;
第五计算单元,用于根据所述第一子矩阵和所述第一转置子矩阵,计算所述第一子矩阵的内部特征关系;
第六计算单元,用于根据所述第一子矩阵的内部特征关系的计算方式,计算所述第一特征矩阵的各子矩阵分别对应的内部特征关系;
拼接单元,用于将所述第一特征矩阵的各子矩阵分别对应的内部特征关系,按照各子矩阵在所述第一特征矩阵中的排布次序,拼接成所述第一特征矩阵的内部特征关系;
第七计算单元,用于对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
本申请实施例的注意力机制采用多头注意力,以增强捕捉特征矩阵的局部结构信息。本申请通过将特征矩阵均匀等分成多个块,匹配多头注意力。举例地,T的中间层输出的特征矩阵F_t为n*h*i,其中m =h*i;S的中间层输出的特征矩阵F_s为n*h*j,其中p=h*j,h表示将特征矩阵均匀分成的块数,多头注意力的头数为h。老师模型的多头自注意力权重分布d_t= softmax(score(F_t,F_t)),特征矩阵由n*h*i转换成h*n*n。学生模型同样使用多头注意力机制时,学生模型的多头自注意力权重分布d_s=softmax(score(F_s,F_s)),特征矩阵由 n*h*j 转换成 h*n*n。在应用多头注意力权重分布的时候需要满足,m=h*i和p=h*j即单列的向量维度可以被头数h整除,便于按照头数对特征矩阵进行均匀等分。本申请其他实施例中老师模型和学生模型可选择一个使用多头注意力机制,另一个使用单头注意力机制,不作限定,只要确保两者输出的序列长度相同,即可使用本申请的知识蒸馏函数进行知识蒸馏。
进一步地,第二计算模块3,包括:
第八计算单元,用于计算所述第一自注意力权重分布和所述第一自注意力权重分布之间的KL散度损失;
作为单元,用于将所述KL散度损失作为所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异。
本申请实施例中,为缩小计算数值并减小计算量,将相似度用softmax函数或者其它函数转化为[0,1]区间的概率值,然后用KLDiv(Kullback-Leibler Divergence, KL散度)去计算KL散度损失,KL散度损失为衡量两个自注意力权重分布之间的分布差异,表示为KLdiv(d_t,d_s),以评估两个自注意力权重分布的分布差异。
进一步地,基于自注意力的知识蒸馏装置,包括:
获取模块,用于获取预设的所述学生模型的任务类型;
匹配模块,用于根据所述学生模型的任务类型匹配全连接层和目标函数,其中,所述全连接层连接于所述学生模型的中间层的输出端;
形成模块,用于根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数;
训练模块,用于根据所述总损失函数在训练集上训练所述学生模型。
本申请实施例中,fc(fully connect,全连接层)为模型的最后一层与最终的任务类型相关。当任务是分类任务时,fc用于分类;当任务是回归任务时,fc用于拟合回归。知识蒸馏时用前述方法构建基于注意力权重的损失函数Loss1,fc层的输出使用与fc层任务类型对应的目标函数作为损失函数TargertLoss,最终训练学生模型时的总损失函数是上述两个损失函数的和,即TotalLoss=Loss1*lambda+TargertLoss,lambda为调节参数,为大于零的实数。本申请实施例的学生模型的训练过程架构示意图如图2所示。
进一步地,形成模块,包括:
获取单元,用于获取所述知识蒸馏损失函数和所述目标函数的数量级别差;
确定单元,用于根据所述数量级别差确定调节参数;
训练单元,用于根据所述调节参数、所述知识蒸馏损失函数和所述目标函数,训练所述学生模型的总损失函数。
本申请实施例根据知识蒸馏损失函数和所述目标函数的数量级别差,选择调节参数,以调节两部分函数的函数值大小对总函数的影响,以平衡两个损失函数值的大小,共同约束学生模型的训练。
参照图4,本申请实施例中还提供一种计算机设备,该计算机设备可以是服务器,其内部结构可以如图4所示。该计算机设备包括通过***总线连接的处理器、存储器、网络接口和数据库。其中,该计算机设计的处理器用于提供计算和控制能力。该计算机设备的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作***、计算机程序和数据库。该内存器为非易失性存储介质中的操作***和计算机程序的运行提供环境。该计算机设备的数据库用于存储基于自注意力的知识蒸馏过程需要的所有数据。该计算机设备的网络接口用于与外部的终端通过网络连接通信。该计算机程序被处理器执行时以实现基于自注意力的知识蒸馏方法。
上述处理器执行上述基于自注意力的知识蒸馏方法,包括:将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
上述计算机设备,通过自注意力机制的注意力权重表征内部结构关系,注意力权重由内部元素两两之间计算获得,无视特征之间的距离,可很好的表达内部结构关系,不机械要求大模型和小模型的类别数或者网络输出特征维度要一致,能满足不同任务类型模型的知识蒸馏训练。
本领域技术人员可以理解,图4中示出的结构,仅仅是与本申请方案相关的部分结构的框图,并不构成对本申请方案所应用于其上的计算机设备的限定。
本申请一实施例还提供一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现基于自注意力的知识蒸馏方法,包括:将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
上述计算机可读存储介质,通过自注意力机制的注意力权重表征内部结构关系,注意力权重由内部元素两两之间计算获得,无视特征之间的距离,可很好的表达内部结构关系,不机械要求大模型和小模型的类别数或者网络输出特征维度要一致,能满足不同任务类型模型的知识蒸馏训练。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,上述的计算机程序可存储于一非易失性计算机可读取存储介质中,该计算机程序在执行时,可包括如上述各方法的实施例的流程。其中,本申请所提供的和实施例中所使用的对存储器、存储、数据库或其它介质的任何引用,均可包括非易失性和/或易失性存储器。非易失性存储器可以包括只读存储器(ROM)、可编程ROM(PROM)、电可编程ROM(EPROM)、电可擦除可编程ROM(EEPROM)或闪存。易失性存储器可包括随机存取存储器(RAM)或者外部高速缓冲存储器。作为说明而非局限,RAM以多种形式可得,诸如静态RAM(SRAM)、动态RAM(DRAM)、同步DRAM(SDRAM)、双速据率SDRAM(SSRSDRAM)、增强型SDRAM(ESDRAM)、同步链路(Synchlink)DRAM(SLDRAM)、存储器总线(Rambus)直接RAM(RDRAM)、直接存储器总线动态RAM(DRDRAM)、以及存储器总线动态RAM(RDRAM)等。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、装置、物品或者方法不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、装置、物品或者方法所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、装置、物品或者方法中还存在另外的相同要素。
以上所述仅为本申请的优选实施例,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

Claims (10)

1.一种基于自注意力的知识蒸馏方法,其特征在于,包括:
将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
2.根据权利要求1所述的基于自注意力的知识蒸馏方法,其特征在于,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤,包括:
将所述第一特征矩阵进行转置计算,得到所述第一特征矩阵对应的第一转置矩阵;
根据所述第一特征矩阵和所述第一转置矩阵,计算所述第一特征矩阵的内部特征关系;
对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
3.根据权利要求1所述的基于自注意力的知识蒸馏方法,其特征在于,所述老师模型使用多头注意力机制,所述根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布的步骤,包括:
将所述第一特征矩阵按照所述多头注意力机制对应的头数进行均分,得到多个子矩阵;
对第一子矩阵进行转置计算,得到所述第一子矩阵对应的第一转置子矩阵,其中,所述第一子矩阵为所述第一特征矩阵对应的多个子矩阵中的任一个;
根据所述第一子矩阵和所述第一转置子矩阵,计算所述第一子矩阵的内部特征关系;
根据所述第一子矩阵的内部特征关系的计算方式,计算所述第一特征矩阵的各子矩阵分别对应的内部特征关系;
将所述第一特征矩阵的各子矩阵分别对应的内部特征关系,按照各子矩阵在所述第一特征矩阵中的排布次序,拼接成所述第一特征矩阵的内部特征关系;
对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
4.根据权利要求2或3所述的基于自注意力的知识蒸馏方法,其特征在于,所述计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异的步骤,包括:
计算所述第一自注意力权重分布和所述第一自注意力权重分布之间的KL散度损失;
将所述KL散度损失作为所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异。
5.根据权利要求4所述的基于自注意力的知识蒸馏方法,其特征在于,所述根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上的步骤之后,包括:
获取预设的所述学生模型的任务类型;
根据所述学生模型的任务类型匹配全连接层和目标函数,其中,所述全连接层连接于所述学生模型的中间层的输出端;
根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数;
根据所述总损失函数在训练集上训练所述学生模型。
6.根据权利要求5所述的基于自注意力的知识蒸馏方法,其特征在于,所述根据所述知识蒸馏损失函数和所述目标函数,形成训练所述学生模型的总损失函数的步骤,包括:
获取所述知识蒸馏损失函数和所述目标函数的数量级别差;
根据所述数量级别差确定调节参数;
根据所述调节参数、所述知识蒸馏损失函数和所述目标函数,训练所述学生模型的总损失函数。
7.一种基于自注意力的知识蒸馏装置,其特征在于,包括:
输入模块,用于将输入数据输入第一模型得到所述第一模型的中间层输出的第一特征矩阵,将所述输入数据输入第二模型得到所述第二模型的中间层输出的第二特征矩阵,其中,所述第一模型为训练好的老师模型,所述第二模型为待训练的学生模型,所述第一特征矩阵和所述第二特征矩阵具有相同的序列长度;
第一计算模块,用于根据所述第一特征矩阵计算所述老师模型对应的第一自注意力权重分布,根据所述第二特征矩阵计算所述学生模型对应的第二自注意力权重分布;
第二计算模块,用于计算所述第一自注意力权重分布和所述第二自注意力权重分布之间的分布差异;
作为模块,用于将所述分布差异,作为所述老师模型和所述学生模型之间的知识蒸馏损失函数;
迁移模块,用于根据所述知识蒸馏损失函数,将所述老师模型的中间层的数据映射关系迁移至所述学生模型的中间层上。
8.根据权利要求7所述的基于自注意力的知识蒸馏装置,其特征在于,所述第一计算模块,包括:
第一计算单元,用于将所述第一特征矩阵进行转置计算,得到所述第一特征矩阵对应的第一转置矩阵;
第二计算单元,用于根据所述第一特征矩阵和所述第一转置矩阵,计算所述第一特征矩阵的内部特征关系;
第三计算单元,用于对所述第一特征矩阵的内部特征关系进行softmax函数计算,得到所述老师模型对应的第一自注意力权重分布。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1至6中任一项所述方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至6中任一项所述的方法的步骤。
CN202110059942.7A 2021-01-18 2021-01-18 基于自注意力的知识蒸馏方法、装置和计算机设备 Active CN112365385B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110059942.7A CN112365385B (zh) 2021-01-18 2021-01-18 基于自注意力的知识蒸馏方法、装置和计算机设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110059942.7A CN112365385B (zh) 2021-01-18 2021-01-18 基于自注意力的知识蒸馏方法、装置和计算机设备

Publications (2)

Publication Number Publication Date
CN112365385A true CN112365385A (zh) 2021-02-12
CN112365385B CN112365385B (zh) 2021-06-01

Family

ID=74535011

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110059942.7A Active CN112365385B (zh) 2021-01-18 2021-01-18 基于自注意力的知识蒸馏方法、装置和计算机设备

Country Status (1)

Country Link
CN (1) CN112365385B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113673254A (zh) * 2021-08-23 2021-11-19 东北林业大学 基于相似度保持的知识蒸馏的立场检测方法
CN114819188A (zh) * 2022-05-19 2022-07-29 北京百度网讯科技有限公司 模型训练方法、装置、电子设备及可读存储介质
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、***和存储介质
CN117116408A (zh) * 2023-10-25 2023-11-24 湖南科技大学 一种面向电子病历解析的关系抽取方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111062489A (zh) * 2019-12-11 2020-04-24 北京知道智慧信息技术有限公司 一种基于知识蒸馏的多语言模型压缩方法、装置
CN111554268A (zh) * 2020-07-13 2020-08-18 腾讯科技(深圳)有限公司 基于语言模型的语言识别方法、文本分类方法和装置
CN111652066A (zh) * 2020-04-30 2020-09-11 北京航空航天大学 基于多自注意力机制深度学习的医疗行为识别方法
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
CN111950643A (zh) * 2020-08-18 2020-11-17 创新奇智(上海)科技有限公司 模型训练方法、图像分类方法及对应装置
CN111967224A (zh) * 2020-08-18 2020-11-20 深圳市欢太科技有限公司 对话文本的处理方法、装置、电子设备及存储介质

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111062489A (zh) * 2019-12-11 2020-04-24 北京知道智慧信息技术有限公司 一种基于知识蒸馏的多语言模型压缩方法、装置
CN111652066A (zh) * 2020-04-30 2020-09-11 北京航空航天大学 基于多自注意力机制深度学习的医疗行为识别方法
CN111554268A (zh) * 2020-07-13 2020-08-18 腾讯科技(深圳)有限公司 基于语言模型的语言识别方法、文本分类方法和装置
CN111950643A (zh) * 2020-08-18 2020-11-17 创新奇智(上海)科技有限公司 模型训练方法、图像分类方法及对应装置
CN111967224A (zh) * 2020-08-18 2020-11-20 深圳市欢太科技有限公司 对话文本的处理方法、装置、电子设备及存储介质
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
WENHUI WANG 等: "MINILM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers", 《HTTPS://ARXIV.ORG/ABS/2002.10957V2》 *

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113673254A (zh) * 2021-08-23 2021-11-19 东北林业大学 基于相似度保持的知识蒸馏的立场检测方法
CN113673254B (zh) * 2021-08-23 2022-06-07 东北林业大学 基于相似度保持的知识蒸馏的立场检测方法
CN114819188A (zh) * 2022-05-19 2022-07-29 北京百度网讯科技有限公司 模型训练方法、装置、电子设备及可读存储介质
CN116778300A (zh) * 2023-06-25 2023-09-19 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、***和存储介质
CN116778300B (zh) * 2023-06-25 2023-12-05 北京数美时代科技有限公司 一种基于知识蒸馏的小目标检测方法、***和存储介质
CN117116408A (zh) * 2023-10-25 2023-11-24 湖南科技大学 一种面向电子病历解析的关系抽取方法
CN117116408B (zh) * 2023-10-25 2024-01-26 湖南科技大学 一种面向电子病历解析的关系抽取方法

Also Published As

Publication number Publication date
CN112365385B (zh) 2021-06-01

Similar Documents

Publication Publication Date Title
CN112365385B (zh) 基于自注意力的知识蒸馏方法、装置和计算机设备
CN109034378B (zh) 神经网络的网络表示生成方法、装置、存储介质和设备
CN111177345B (zh) 基于知识图谱的智能问答方法、装置和计算机设备
CN108763535A (zh) 信息获取方法及装置
CN112149797B (zh) 神经网络结构优化方法和装置、电子设备
CN113673698A (zh) 适用于bert模型的蒸馏方法、装置、设备及存储介质
CN111429923B (zh) 说话人信息提取模型的训练方法、装置和计算机设备
CN111078847A (zh) 电力用户意图识别方法、装置、计算机设备和存储介质
WO2021082488A1 (zh) 基于文本匹配的智能面试方法、装置和计算机设备
CN113190688A (zh) 基于逻辑推理和图卷积的复杂网络链接预测方法及***
CN111259113A (zh) 文本匹配方法、装置、计算机可读存储介质和计算机设备
CN112699215B (zh) 基于胶囊网络与交互注意力机制的评级预测方法及***
CN114782775A (zh) 分类模型的构建方法、装置、计算机设备及存储介质
CN112613555A (zh) 基于元学习的目标分类方法、装置、设备和存储介质
JP2018185771A (ja) 文ペア分類装置、文ペア分類学習装置、方法、及びプログラム
CN113610163A (zh) 一种基于知识蒸馏的轻量级苹果叶片病害识别方法
CN114580388A (zh) 数据处理方法、对象预测方法、相关设备及存储介质
CN115374792A (zh) 联合预训练和图神经网络的政策文本标注方法及***
CN113223504B (zh) 声学模型的训练方法、装置、设备和存储介质
CN112634870B (zh) 关键词检测方法、装置、设备和存储介质
CN115905848A (zh) 基于多模型融合的化工过程故障诊断方法及***
CN113792110A (zh) 一种基于社交物联网的设备信任值评估方法
CN113486140A (zh) 知识问答的匹配方法、装置、设备及存储介质
CN109034387A (zh) 一种基于伪逆学习快速训练自编码器的近似方法
CN116738983A (zh) 模型进行金融领域任务处理的词嵌入方法、装置、设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant