CN112529209A - 模型训练方法、装置以及计算机可读存储介质 - Google Patents

模型训练方法、装置以及计算机可读存储介质 Download PDF

Info

Publication number
CN112529209A
CN112529209A CN202011427624.3A CN202011427624A CN112529209A CN 112529209 A CN112529209 A CN 112529209A CN 202011427624 A CN202011427624 A CN 202011427624A CN 112529209 A CN112529209 A CN 112529209A
Authority
CN
China
Prior art keywords
model
data processing
sample
training
processing model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011427624.3A
Other languages
English (en)
Inventor
孟嘉琪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Yuncong Enterprise Development Co ltd
Original Assignee
Shanghai Yuncong Enterprise Development Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Yuncong Enterprise Development Co ltd filed Critical Shanghai Yuncong Enterprise Development Co ltd
Priority to CN202011427624.3A priority Critical patent/CN112529209A/zh
Publication of CN112529209A publication Critical patent/CN112529209A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Complex Calculations (AREA)

Abstract

本发明涉及机器学习技术领域,具体提供了一种模型训练方法、装置以及计算机可读存储介质,旨在解决如何提高模型训练效果的技术问题。根据本发明实施例的方法,可以利用初始训练集对预设的数据处理模型训练得到第一数据处理模型;获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值,第二数据处理模型是根据初始训练集下不同的子训练集训练得到的,不同的子训练集之间相差一个或多个不同的被删除样本;最后根据该差值获取异常样本以对初始训练集进行优化,利用优化后的初始训练集对第一数据处理模型进行训练。基于上述步骤,本发明不仅能够从训练集中快速且准确地筛选出异常样本,还极大地提高了模型训练效果。

Description

模型训练方法、装置以及计算机可读存储介质
技术领域
本发明涉及机器学习技术领域,具体涉及一种模型训练方法、装置以及计算机可读存储介质。
背景技术
机器学习技术领域中的监督学习主要是利用训练样本与样本标签对模型进行训练,而为了提高模型的训练效果,需要使用数量级较大的训练样本如百万级别的训练样本并且提前为每个训练样本分别标注好准确的样本标签,才能保证训练好的模型具备较高的模型性能。例如:利用百万级别的训练样本以及每个训练样本各自对应的类别标签,对数据分类模型进行训练,以使训练好的数据分类模型具备较高的分类性能。由于训练样本的数量级过大,在对训练样本进行标签标注时,无法保证对每个训练样本都进行准确地标签标注,如果利用这些标签错误的噪声样本进行模型训练,会降低模型的训练效果。
发明内容
为了克服上述缺陷,提出了本发明,以提供解决或至少部分地解决如何提高模型训练效果的技术问题的模型训练方法、装置以及计算机可读存储介质。
第一方面,提供一种模型训练方法,所述方法包括:
利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型;
利用测试集对所述第一数据处理模型以及多个第二数据处理模型分别进行测试,获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值;
根据所述第一模型损失差值获取所述初始训练集内的异常样本,并且根据所述异常样本对所述初始训练集进行样本调整,获取优化的训练集;
利用所述优化的训练集对所述第一数据处理模型进行训练,以获取最终的数据处理模型;
其中,不同的第二数据处理模型被配置成根据所述初始训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被删除样本。
在上述模型训练方法的一个技术方案中,“获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值”的步骤具体包括:
获取利用所述初始训练集对所述预设的数据处理模型进行训练后,得到的所述第一数据处理模型的多个备选数据处理模型;
利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取所述最终的第一数据处理模型的第一模型参数;
根据所述第一模型参数,拟合利用所述测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取所述最终的第二数据处理模型的第二模型参数,其中,所述多个备选数据处理模型是利用所述当前被删除样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
采用稳健统计方法,对所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被删除样本对应的第一模型损失差值。
在上述模型训练方法的一个技术方案中,所述方法包括根据所述第一模型参数并且按照下式所示的方法,拟合得到所述第二模型参数:
根据所述第一模型参数并且按照下式所示的方法,拟合得到所述第二模型参数:
Figure BDA0002819692220000021
其中,所述
Figure BDA0002819692220000022
表示拟合得到的所述最终的第二数据处理模型的第二模型参数,所述zdel表示所述当前被删除样本;所述
Figure BDA0002819692220000023
表示所述第一模型参数;所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数;所述zi表示所述训练集内的第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本的标签,i=1,...,n;所述ε1表示预设的所述当前被删除样本zdel的样本权重且
Figure BDA0002819692220000031
在上述模型训练方法的一个技术方案中,“获取所述当前被删除样本对应的第一模型损失差值”的步骤具体包括:
基于稳健统计方法中的影响函数理论,构建下式所示的所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第一模型损失差值:
Figure BDA0002819692220000032
其中,所述Γup,loss(zdel,ztest)表示所述第一模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure BDA0002819692220000033
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure BDA0002819692220000034
计算出的梯度向量进行转置;所述
Figure BDA0002819692220000035
表示所述第二模型参数;所述zdel表示所述当前被删除样本,所述
Figure BDA0002819692220000036
表示所述最终的第二数据处理模型的经验风险的Hessian矩阵且
Figure BDA0002819692220000037
在上述模型训练方法的一个技术方案中,“根据所述第一模型损失差值获取所述初始训练集内的异常样本”的步骤具体包括:
对所述第一模型损失差值进行由负至正的逆向排序;
根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值;
根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练所述第二数据处理模型时的被删除样本,将所述被删除样本作为异常样本。
在上述模型训练方法的一个技术方案中,“根据所述异常样本对所述训练集进行样本调整”的步骤具体包括:
获取所述异常样本的样本标签;
判断所述样本标签是否正确;
若正确,则删除所述异常样本;
若不正确,则修正所述异常样本的样本标签。
在上述模型训练方法的一个技术方案中,所述方法还包括通过下列方式获取对抗训练集,以便利用所述优化后的训练集与所述对抗训练集对预设的生成对抗网络模型进行训练:
利用所述优化后的训练集对所述预设的数据处理模型进行训练,获取第三数据处理模型;
利用所述测试集对所述第三数据处理模型以及多个第四数据处理模型分别进行测试,获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值;其中,不同的第四数据处理模型被配置成根据所述优化后的训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被扰动样本;
根据每个所述第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个所述第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个所述第四数据处理模型各自对应的最大的第二模型损失差值;
获取所述最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据所述扰动量对所述被扰动样本进行扰动形成新的样本,以根据所述新的样本构建所述对抗训练集。
在上述模型训练方法的一个技术方案中,“获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值”的步骤具体包括:
获取利用所述优化后的训练集对所述预设的数据处理模型进行训练后,得到的所述第三数据处理模型的多个备选数据处理模型;
利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取所述最终的第三数据处理模型的第三模型参数;
根据所述第三模型参数,拟合利用所述测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取所述最终的第四数据处理模型的第四模型参数,其中,所述多个备选数据处理模型是利用所述当前被扰动样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
采用稳健统计方法,对所述第四模型参数与所述最终的第四数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被扰动样本对应的第二模型损失差值。
在上述模型训练方法的一个技术方案中,所述方法包括根据所述第三模型参数并且按照下式所示的方法,拟合得到所述第四模型参数:
根据所述第三模型参数并且按照下式所示的方法,拟合得到所述第四模型参数:
Figure BDA0002819692220000051
其中,所述
Figure BDA0002819692220000052
表示拟合得到的所述最终的第四数据处理模型的第四模型参数,所述zδ表示对当前被扰动样本z增加扰动量δ后形成的新的样本且zδ=(x+δ,y),x表示样本zδ中的图像样本,y表示所述图像样本x的标签;所述
Figure BDA0002819692220000053
表示所述第三模型参数;所述zi表示所述训练集内第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本xi的标签,i=1,...,n;所述ε2表示预设的所述当前被扰动样本z的样本权重且
Figure BDA0002819692220000054
在上述模型训练方法的一个技术方案中,“获取所述当前被扰动样本对应的第二模型损失差值”的步骤具体包括:
基于稳健统计方法中的影响函数理论,构建下式所示的所述第四模型参数与所述最终的第四模型参数在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第二模型损失差值:
Figure BDA0002819692220000055
其中,所述Γpert,loss(z,ztest)表示所述第二模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure BDA0002819692220000056
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure BDA0002819692220000057
计算出的梯度向量进行转置;所述
Figure BDA0002819692220000058
表示所述第四模型参数;所述z表示所述当前被扰动样本,所述
Figure BDA0002819692220000059
表示所述最终的第四数据处理模型的经验风险的Hessian矩阵且
Figure BDA00028196922200000510
所述
Figure BDA00028196922200000511
表示在样本z被增加扰动量δ之前以及之后根据所述损失函数L计算出的损失值的差值在图像样本x处对应的一阶泰勒展开式。
第二方面,提供一种模型训练装置,所述装置包括:
第一数据处理模型获取模块,其被配置成利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型;
第一损失差值获取模块,其被配置成利用测试集对所述第一数据处理模型以及多个第二数据处理模型分别进行测试,获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值;
训练集优化模块,其被配置成根据所述第一模型损失差值获取所述初始训练集内的异常样本,并且根据所述异常样本对所述初始训练集进行样本调整,获取优化的训练集;
模型训练模块,其被配置成利用所述优化的训练集对所述第一数据处理模型进行训练,以获取最终的数据处理模型;
其中,不同的第二数据处理模型被配置成根据所述初始训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被删除样本。
在上述模型训练装置的一个技术方案中,所述第一损失差值获取模块包括第一备选模型获取单元、第一参数获取单元、第二参数获取单元和第一损失差值获取单元;
所述第一备选模型获取单元被配置成获取利用所述初始训练集对所述预设的数据处理模型进行训练后,得到的所述第一数据处理模型的多个备选数据处理模型;
所述第一参数获取单元被配置成利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取所述最终的第一数据处理模型的第一模型参数;
所述第二参数获取单元被配置成根据所述第一模型参数,拟合利用所述测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取所述最终的第二数据处理模型的第二模型参数,其中,所述多个备选数据处理模型是利用所述当前被删除样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
所述第一损失差值获取单元被配置成采用稳健统计方法,对所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被删除样本对应的第一模型损失差值。
在上述模型训练装置的一个技术方案中,所述第二参数获取单元被进一步配置成根据所述第一模型参数并且按照下式所示的方法,拟合得到所述第二模型参数:
Figure BDA0002819692220000071
其中,所述
Figure BDA0002819692220000072
表示拟合得到的所述最终的第二数据处理模型的第二模型参数,所述zdel表示所述当前被删除样本;所述
Figure BDA0002819692220000073
表示所述第一模型参数;所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数;所述zi表示所述训练集内的第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本的标签,i=1,...,n;所述ε1表示预设的所述当前被删除样本zdel的样本权重且
Figure BDA0002819692220000074
在上述模型训练装置的一个技术方案中,所述第一损失差值获取单元被进一步配置成基于稳健统计方法中的影响函数理论,构建下式所示的所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第一模型损失差值:
Figure BDA0002819692220000075
其中,所述Γup,loss(zdel,ztest)表示所述第一模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure BDA0002819692220000076
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure BDA0002819692220000077
计算出的梯度向量进行转置;所述
Figure BDA0002819692220000078
表示所述第二模型参数;所述zdel表示所述当前被删除样本,所述
Figure BDA0002819692220000079
表示所述最终的第二数据处理模型的经验风险的Hessian矩阵且
Figure BDA0002819692220000081
在上述模型训练装置的一个技术方案中,所述训练集优化模块被进一步配置成执行以下操作:
对所述第一模型损失差值进行由负至正的逆向排序;
根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值;
根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练所述第二数据处理模型时的被删除样本,将所述被删除样本作为异常样本。
在上述模型训练装置的一个技术方案中,所述训练集优化模块被进一步配置成执行以下操作:
获取所述异常样本的样本标签;
判断所述样本标签是否正确;
若正确,则删除所述异常样本;
若不正确,则修正所述异常样本的样本标签。
在上述模型训练装置的一个技术方案中,所述装置还包括:
第三数据处理模型获取模块,其被配置成利用所述优化后的训练集对所述预设的数据处理模型进行训练,获取第三数据处理模型;
第二损失差值获取模块,其被配置成利用所述测试集对所述第三数据处理模型以及多个第四数据处理模型分别进行测试,获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值;其中,不同的第四数据处理模型被配置成根据所述优化后的训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被扰动样本;
第三损失差值获取模块,其被配置成根据每个所述第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个所述第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个所述第四数据处理模型各自对应的最大的第二模型损失差值;
对抗训练集获取模块,其被配置成获取所述最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据所述扰动量对所述被扰动样本进行扰动形成新的样本,以根据所述新的样本构建所述对抗训练集。
在上述模型训练装置的一个技术方案中,所述第二损失差值获取模块包括第二备选模型获取单元、第三参数获取单元、第四参数获取单元和第二损失差值获取单元;
所述第二备选模型获取单元被配置成获取利用所述优化后的训练集对所述预设的数据处理模型进行训练后,得到的所述第三数据处理模型的多个备选数据处理模型;
所述第三参数获取单元被配置成利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取所述最终的第三数据处理模型的第三模型参数;
所述第四参数获取单元被配置成根据所述第三模型参数,拟合利用所述测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取所述最终的第四数据处理模型的第四模型参数,其中,所述多个备选数据处理模型是利用所述当前被扰动样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
所述第二损失差值获取单元被配置成采用稳健统计方法,对所述第四模型参数与所述最终的第四数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被扰动样本对应的第二模型损失差值。
在上述模型训练装置的一个技术方案中,所述第四参数获取单元被进一步配置成根据所述第三模型参数并且按照下式所示的方法,拟合得到所述第四模型参数:
Figure BDA0002819692220000091
其中,所述
Figure BDA0002819692220000092
表示拟合得到的所述最终的第四数据处理模型的第四模型参数,所述zδ表示对当前被扰动样本z增加扰动量δ后形成的新的样本且zδ=(x+δ,y),x表示样本zδ中的图像样本,y表示所述图像样本x的标签;所述
Figure BDA0002819692220000093
表示所述第三模型参数;所述zi表示所述训练集内第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本xi的标签,i=1,...,n;所述ε2表示预设的所述当前被扰动样本z的样本权重且
Figure BDA0002819692220000101
在上述模型训练装置的一个技术方案中,所述第二损失差值获取单元被进一步配置成基于稳健统计方法中的影响函数理论,构建下式所示的所述第四模型参数与所述最终的第四模型参数在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第二模型损失差值:
Figure BDA0002819692220000102
其中,所述Γpert,loss(z,ztest)表示所述第二模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure BDA0002819692220000103
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure BDA0002819692220000104
计算出的梯度向量进行转置;所述
Figure BDA0002819692220000105
表示所述第四模型参数;所述z表示所述当前被扰动样本,所述
Figure BDA0002819692220000106
表示所述最终的第四数据处理模型的经验风险的Hessian矩阵且
Figure BDA0002819692220000107
所述
Figure BDA0002819692220000108
表示在样本z被增加扰动量δ之前以及之后根据所述损失函数L计算出的损失值的差值在图像样本x处对应的一阶泰勒展开式。
第三方面,提供一种控制装置,该控制装置包括处理器和存储装置,所述存储装置适于存储多条程序代码,所述程序代码适于由所述处理器加载并运行以执行上述模型训练方法的技术方案中任一项技术方案所述的模型训练方法。
第四方面,提供一种计算机可读存储介质,该计算机可读存储介质其中存储有多条程序代码,所述程序代码适于由处理器加载并运行以执行上述模型训练方法的技术方案中任一项技术方案所述的模型训练方法。
本发明上述一个或多个技术方案,至少具有如下一种或多种有益效果:
在实施本发明的技术方案中,可以先利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型;然后利用测试集对第一数据处理模型以及多个第二数据处理模型分别进行测试,获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值;其中,不同的第二数据处理模型被配置成根据初始训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被删除样本。进一步,根据第一模型损失差值获取初始训练集内的异常样本,如果删除某个样本后得到的模型损失增加,则表明这个被删除样本对模型训练是有益的样本;如果删除某个样本后的模型损失减小,则表明这个被删除样本对模型训练是有害的样本,因此可以判定这个样本属于异常样本。进而根据异常样本对初始训练集进行样本调整,获取优化的训练集。例如:对第一模型损失差值进行由负至正的逆向排序,选取排序顺序小于等于预设顺序值的第一模型损失差值,获取这些第一模型损失差值对应的第二数据处理模型,获取在训练这些第二数据处理模型时的被删除样本(异常样本),将这些样本从初始训练集中永久删除,形成优化的训练集。最后利用优化的训练集对第一数据处理模型进行训练,以获取最终的数据处理模型。通过上述步骤,本发明实施例能够根据第一模型损失差值的变化情况,从训练集中快速且准确地筛选出异常样本,克服了现有技术中采用人工核查的方式对训练集进行样本核查,导致的核查效率低且容易发生漏检和错检的缺陷。同时,也极大地提高了模型训练效果。
进一步,在实施本发明的技术方案中,可以采用稳健统计方法对第二数据处理模型的模型参数以及其在测试集上的模型损失进行影响分析,根据影响分析的结果直接获取到第一数据处理模型与第二数据处理模型在测试集上的第一模型损失差值,无需在利用初始训练集对预设的数据处理模型训练得到的第一数据处理模型后,再利用删除训练样本后子训练集对预设的数据处理模型训练得到的第二数据处理模型,然后分别获取第一数据处理模型与第二数据处理模型在测试集上的模型损失,最后对这两个模型损失进行差值计算来得到第一模型损失差值,即省去了对第二数据处理模型的训练过程,直接通过影响分析就可以得“第一数据处理模型与第二数据处理模型在测试集上的模型损失之间的第一模型损失差值”,从而极大地提高了第一模型损失差值的获取效率,有利于对异常样本进行快速筛查。
附图说明
下面参照附图来描述本发明的具体实施方式,附图中:
图1是根据本发明的一个实施例的模型训练方法的主要步骤流程示意图;
图2是根据本发明的一个实施例的第一模型损失差值获取方法的主要步骤流程示意图;
图3是根据本发明的另一个实施例的模型训练方法的主要步骤流程示意图;
图4是根据本发明的一个实施例的第二模型损失差值获取方法的主要步骤流程示意图;
图5是根据本发明的一个实施例的模型训练装置的主要结构框图;
图6是根据本发明的另一个实施例的模型训练装置的主要结构框图;
附图标记列表:
31:第一数据处理模型获取模块;32:第一损失差值获取模块;33:训练集优化模块;34:模型训练模块;41:第三数据处理模型获取模块;42:第二损失差值获取模块;43:第三损失差值获取模块;44:对抗训练集获取模块。
具体实施方式
下面参照附图来描述本发明的一些实施方式。本领域技术人员应当理解的是,这些实施方式仅仅用于解释本发明的技术原理,并非旨在限制本发明的保护范围。
在本发明的描述中,“模块”、“处理器”可以包括硬件、软件或者两者的组合。一个模块可以包括硬件电路,各种合适的感应器,通信端口,存储器,也可以包括软件部分,比如程序代码,也可以是软件和硬件的组合。处理器可以是中央处理器、微处理器、图像处理器、数字信号处理器或者其他任何合适的处理器。处理器具有数据和/或信号处理功能。处理器可以以软件方式实现、硬件方式实现或者二者结合方式实现。非暂时性的计算机可读存储介质包括任何合适的可存储程序代码的介质,比如磁碟、硬盘、光碟、闪存、只读存储器、随机存取存储器等等。术语“A和/或B”表示所有可能的A与B的组合,比如只是A、只是B或者A和B。术语“至少一个A或B”或者“A和B中的至少一个”含义与“A和/或B”类似,可以包括只是A、只是B或者A和B。单数形式的术语“一个”、“这个”也可以包含复数形式。
这里先解释本发明涉及到的一些术语。
稳健统计方法指的是,数理统计技术领域中常规的统计方法,该方法能够描述观测值对估计量的影响。在本发明实施例中的观测值可以是训练样本权重,也可以是训练样本的扰动量,而估计量则是数据处理模型的模型损失,即本发明实施例采用稳健统计方法的目的是为了分析改变训练样本权重或增加扰动量之后,对数据处理模型的模型损失带来了哪些影响。稳健统计方法中的影响函数理论指的是,通过构建观测值与估计量之间的影响函数(Influence Function,IF),利用影响函数量化分析观测值对估计量的影响。需要说明的是,稳健统计方法及其影响函数理论均是数理统计技术领域中的常规技术,为了描述简洁,在此不再对稳健统计方法及其影响函数理论进行具体说明。
生成对抗网络模型指的是,基于生成对抗网络结构(Generative AdversarialNetworks,GAN)构建的模型。而GAN是人工智能技术领域中常规的网络结构,为了描述简洁,在此不再对GAN的具体结构、功能和训练方法进行赘述。
目前传统的样本标签方法主要是利用人工标注的方式进行标签,而采用人工标注的方式对数量级较大的训练样本如百万级别的训练样本进行标签标注,很容易标注错误,如果利用这些标签错误的噪声样本进行模型训练,会降低模型的训练效果。但是,由于训练样本的数量级较大,如果继续采用人工核查的方式对训练样本进行核查,以筛选出噪声样本,不仅费时费力,还很容易发生漏检和错检。
在本发明实施例中,可以先利用初始训练集对预设的数据处理模型(例如:数据分类模型)进行训练,获取第一数据处理模型;然后利用测试集对第一数据处理模型以及多个第二数据处理模型分别进行测试,获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值;其中,不同的第二数据处理模型被配置成根据初始训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被删除样本。进一步,根据第一模型损失差值获取初始训练集内的异常样本(例如:标签标注错误的样本),如果删除某个样本后得到的模型损失增加(第一模型损失差值变大),则表明这个被删除样本对模型训练是有益的样本;如果删除某个样本后的模型损失减小(第一模型损失差值变小),则表明这个被删除样本对模型训练是有害的样本,因此可以判定这个样本属于异常样本。进而根据异常样本对初始训练集进行样本调整,获取优化的训练集。例如:对第一模型损失差值进行由负至正的逆向排序,选取排序顺序小于等于预设顺序值的第一模型损失差值,获取这些第一模型损失差值对应的第二数据处理模型,获取在训练这些第二数据处理模型时的被删除样本(异常样本),将这些样本从初始训练集中永久删除,形成优化的训练集。最后利用优化的训练集对第一数据处理模型进行训练,以获取最终的数据处理模型。通过上述步骤,本发明实施例能够根据第一模型损失差值的变化情况,从训练集中快速且准确地筛选出异常样本,克服了现有技术中采用人工核查的方式对训练集进行样本核查,导致的核查效率低且容易发生漏检和错检的缺陷。同时,也极大地提高了模型训练效果。
参阅附图1,图1是根据本发明的一个实施例的模型训练方法的主要步骤流程示意图。如图1所示,本发明实施例中的模型训练方法主要包括以下步骤:
步骤S101:利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型。
需要说明的是,在本实施例中可以利用机器学习技术领域中常规的模型训练方法对预设的数据处理模型进行训练,为了描述简洁,在此不再对上述模型训练方法进行赘述。
步骤S102:利用测试集对第一数据处理模型以及多个第二数据处理模型分别进行测试,获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值。
不同的第二数据处理模型可以被配置成根据初始训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被删除样本。一个例子:如果初始训练集包括样本1、样本2和样本3,那么同时利用样本1-3对预设的数据处理模型进行训练可以得到上述步骤S101中所述的第一数据处理模型。如果将样本1、样本2和样本3分别从初始训练集中删除,形成子训练集1-3,然后利用子训练集1-3分别对预设的数据处理模型进行训练可以得到如下表1所示的第二数据处理模型。
表1
Figure BDA0002819692220000151
第一模型损失差值指的是,第二数据处理模型在测试集上的模型损失减去第一数据处理模型在测试集上的模型损失后,得到的差值。
第一数据处理模型在测试集上的模型损失指的是,在利用测试集内每个测试样本对第一数据处理模型进行测试,得到的每个测试样本各自对应的模型损失以后,将这些模型损失求平均得到的平均损失。需要说明的是,在本实施例中在获取到每个测试样本各自对应的模型损失以后,也可以采用机器学习技术领域中其他常规的获取模型在测试集上的模型损失的方法,对这些模型损失进行计算,得到第一数据处理模型在测试集上的模型损失。
第二数据处理模型在测试集上的模型损失指的是,在利用测试集内每个测试样本对第二数据处理模型进行测试,得到的每个测试样本各自对应的模型损失以后,将这些模型损失求平均得到的平均损失。需要说明的是,在本实施例中在获取到每个测试样本各自对应的模型损失以后,也可以采用机器学习技术领域中其他常规的获取模型在测试集上的模型损失的方法,对这些模型损失进行计算,得到第二数据处理模型在测试集上的模型损失。
此外,需要说明的是,在本实施例中可以采用与获取第一数据处理模型相同的模型训练方法对预设的数据处理模型进行训练,得到每个第二数据处理模型。
参阅附图2,在本实施例中可以按照以下步骤S1021-S1024所示的方法获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值。
步骤S1021:获取利用初始训练集对预设的数据处理模型进行训练后,得到的第一数据处理模型的多个备选数据处理模型。
第一数据处理模型的多个备选数据处理模型指的是,在对预设的数据处理模型进行训练时通过调整模型参数和/或模型结构,得到的多个均能够满足预设的模型训练要求(例如:数据分类的准确率大于等于预设的准确率阈值)的模型。
步骤S1022:利用测试集对步骤S1021得到的多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取最终的第一数据处理模型的第一模型参数。
需要说明的是,在本实施例中可以利用机器学习技术领域中常规的模型测试方法对每个备选数据处理模型分别进行测试,以从训练好的多个备选数据处理模型中获取最优的备选数据处理模型。
步骤S1023:根据第一模型参数,拟合利用测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取最终的第二数据处理模型的第二模型参数。
第二数据处理模型的多个备选数据处理模型是利用当前被删除样本对应的子训练集对预设的数据处理模型进行训练后得到的,这些备选数据处理模型指的是在对预设的数据处理模型进行训练时通过调整模型参数和/或模型结构,得到的多个均能够满足预设的模型训练要求(例如:数据分类的准确率大于等于预设的准确率阈值)的模型。
具体而言,在本实施例中可以根据第一模型参数并且按照下式(1)所示的方法,拟合得到第二模型参数:
Figure BDA0002819692220000161
公式(1)中各参数含义如下:
Figure BDA0002819692220000162
表示拟合得到的最终的第二数据处理模型的第二模型参数,zdel表示当前被删除样本;
Figure BDA0002819692220000163
表示第一模型参数;L表示对预设的数据处理模型进行训练与测试时使用的损失函数;zi表示训练集内的第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示图像样本的标签,i=1,…,n;ε1表示预设的当前被删除样本zdel的样本权重且
Figure BDA0002819692220000164
步骤S1024:采用稳健统计方法,对步骤S1023得到的第二模型参数与最终的第二数据处理模型在测试集上的模型损失进行影响分析,以获取所述当前被删除样本对应的第一模型损失差值。
采用稳健统计方法对第二模型参数
Figure BDA0002819692220000171
与最终的第二数据处理模型在测试集上的模型损失进行影响分析,能够直接获取到“第一数据处理模型与最终的第二数据处理模型在测试集上的模型损失之间的第一模型损失差值”,而无需在利用初始训练集对预设的数据处理模型训练得到的第一数据处理模型后,再利用子训练集对预设的数据处理模型训练得到的第二数据处理模型,然后获取第一数据处理模型与第二数据处理模型在测试集上的模型损失,最后对这两个模型损失进行差值计算来得到第一模型损失差值,即省去了对第二数据处理模型的训练过程,直接通过影响分析就可以得“第一数据处理模型与最终的第二数据处理模型在测试集上的模型损失之间的第一模型损失差值”,从而极大地提高了第一模型损失差值的获取效率,有利于对异常样本进行快速筛查。
在本发明实施例的一个实施方式中,基于稳健统计方法,可以按照以下获取来获取第一模型损失差值:
基于稳健统计方法中的影响函数理论,构建下式(2)所示的第二模型参数
Figure BDA0002819692220000172
与最终的第二数据处理模型在测试集上的模型损失的影响函数,根据影响函数计算第一模型损失差值:
Figure BDA0002819692220000173
公式(2)中各参数含义如下:
Γup,loss(zdel,ztest)表示第一模型损失差值,ztest表示测试集,L表示对预设的数据处理模型进行训练与测试时使用的损失函数,
Figure BDA0002819692220000174
表示根据损失函数L计算出的损失值对模型参数θ求梯度,T表示对经
Figure BDA0002819692220000175
计算出的梯度向量进行转置;
Figure BDA0002819692220000176
表示第二模型参数;zdel表示当前被删除样本,
Figure BDA0002819692220000177
表示最终的第二数据处理模型的经验风险的Hessian矩阵且
Figure BDA0002819692220000178
经验风险指的是数据处理模型对训练集内每个样本的模型损失的平均值,其能够度量数据处理模型的训练效果。如果经验风险越小,则表明数据处理模型的训练效果越好;反之,数据处理模型的训练效果越差。需要说明的是,经验风险是机器学习技术领域中的常规技术,为了描述简洁,在此不再赘述。
此外,需要说明的是,在本实施例中可以采用数学技术领域中常规的梯度计算方法来根据损失函数L计算出的损失值对模型参数θ求梯度,为了描述简洁,在此不再对梯度计算方法的具体工作进行赘述。
下面对上述公式(2)所示的第二模型参数
Figure BDA0002819692220000181
与模型损失的影响函数的构建过程进行简单说明。
首先,在通过步骤S1022拟合得到的被删除样本zdel对应的第二模型参数
Figure BDA0002819692220000182
之后,基于稳健统计方法中的影响函数理论,构建下式(3)所示的预设的被删除样本zdel的样本权重ε1与模型参数
Figure BDA0002819692220000183
的影响函数:
Figure BDA0002819692220000184
根据公式(1)求解公式(3)可以得到下式(4):
Figure BDA0002819692220000185
在公式(4)中,
Figure BDA0002819692220000186
表示被删除样本zdel对应的第二数据处理模型的经验风险的Hessian矩阵,
Figure BDA0002819692220000187
并且
Figure BDA0002819692220000188
是正定矩阵。
通过上述公式(3)-(4)就可以估计出由于删除某个训练样本引起的模型参数的变化,而无需重新利用删除训练样本后的子训练集对数据处理模型进行训练,来获取新的模型参数。
然后,利用链式法则,分析改变某个训练样本的权重(增加样本权重ε1)对测试集进行测试的结果带来的影响,也就是评估
Figure BDA0002819692220000189
在测试集上引起的模型损失的变化。具体而言,利用链式法则构建如下式(5)所示的影响函数:
Figure BDA00028196922200001810
对公式(5)展开可以得到下式(6):
Figure BDA00028196922200001811
将公式(3)-(4)代入到公式(6),即可得到公式(2)所示影响函数的解析式。
在本实施例中通过第二模型参数
Figure BDA00028196922200001812
与最终的第二数据处理模型在测试集上的模型损失的影响函数,可以对模型参数的变化对数据处理模型的模型损失产生的影响进行量化分析,根据影响函数能够直接计算出第二模型参数
Figure BDA0002819692220000191
对第二数据处理模型在测试集上的模型损失的影响值(第一模型损失差值),从而极大地提高了获取第一模型损失差值的效率,有利于对异常样本的快速筛查。
步骤S103:根据第一模型损失差值获取初始训练集内的异常样本,并且根据异常样本对初始训练集进行样本调整,获取优化的训练集。
在本实施例中可以按照以下步骤11-步骤13获取异常样本。
步骤11:对第一模型损失差值进行由负至正的逆向排序。如果第一模型损失差值越大,则表明删除相应的样本对模型的坏处越大,也就是说,这个被删除样本是对模型训练有益的有益样本;如果第一模型损失差值越小,则表明删除相应的样本对模型的坏处越小,也就是说,这个被删除样本是对模型训练有害的有害样本。因此,可以通过对第一模型损失差值进行由负至正的逆向排序,即危害程度由大至小的顺序进行排序,以便根据逆向排序的结果快速选取出危害程度较大的异常样本。类似的,也可以通过对第一模型损失差值进行由正至负的正向排序,即有益程度由大至小的顺序进行排序,以便根据正向排序的结果快速选取出有益程度较大的有益样本。
一个例子:如果分别利用删除了样本1-10后的“子训练集1-10”训练得到了第二数据处理模型1-10,而第二数据处理模型1-10各自对应的第一模型损失差值依次是-1、-2、-3、-4、-5、1、2、3、4、5,那么对这些第一模型损失差值进行由负至正的逆向排序可以得到-5、-4、-3、-2、-1、1、2、3、4、5。
步骤12:根据步骤11获取到的根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值。
继续参阅上面的例子,如果预设顺序值为2,那么选取到的第一模型损失差值就是-5和-4。
步骤13:根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练第二数据处理模型时的被删除样本,将被删除样本作为异常样本。
继续参阅上面的例子,如果步骤12选取到的第一模型损失差值是-5和-4,那么可以得到异常样本是训练集中的样本5和4。
在本实施例中除了可以按照上述步骤11-13筛选异常样本,还可以根据第一模型损失差值的逆向排序的结果,选取多个样本进行样本特征分析,分析这些样本都有哪些共同特征,就可以得出数据处理模型在训练时其主要关注了这些样本的哪些特征,进一步判断这些关注的特征是否符合训练目的,如果不符合,可以有针对性的调整数据处理模型的模型参数和/或模型结构和/或训练方法。同样的,除了可以对第一模型损失差值进行逆向排序进行特征分析,还可以对第一模型损失差值进行正向排序,然后按照差值由小至大的顺序选取多个样本进行特征分析,分析这些训练样本都有哪些共同特征,就可以得出数据处理模型在训练时其主要关注了样本的哪些特征,进一步判断这些关注的特征是否符合训练目的,如果不符合,可以有针对性的调整数据处理模型的模型参数和/或模型结构和/或训练方法。
在本实施例中可以按照以下步骤21-步骤22对训练集进行样本调整。
步骤21:获取异常样本的样本标签。
步骤22:判断样本标签是否正确。
如果样本标签正确,则表明这个异常样本不适合对数据处理模型进行训练,数据处理模型无法从这个异常样本中学习到相应的能力。例如:训练数据处理模型的目的是使其能够对图像中的车辆进行分类,判断这个车辆属于机动车还是非机动车,因此训练样本的样本标签可以包括机动车和非机动车。如果获取到的异常样本是一个机动车图像且样本标签是机动车(样本标签正确),但是图像中机动车的大部分区域都被建筑物遮挡,使得数据处理模型无法从这个样本中学习到这个图像是机动车图像还是非机动车,因此需要删除这个样本。
如果样本标签错误,则直接修正这个异常样本的样本标签。继续参阅上面的例子,如果获取到的异常样本是一个机动车图像,样本标签是非机动车,显然这个异常样本的样本标签是错误的,因此将其样本标签修正为机动车即可。
步骤S104:利用优化的训练集对第一数据处理模型进行训练,以获取最终的数据处理模型。
在本实施例中,可以采用获取第一数据处理模型时使用的模型训练方法,继续对该第一数据处理模型进行训练。此外,在本实施例中也可以采用机器学习技术领域中与上述模型训练方法不同的其他常规的模型训练方法,对第一数据处理模型进行模型训练。为了描述简洁,在此不再对上述模型训练的具体过程进行赘述。
通过上述步骤S101-步骤S104,本发明实施例能够从训练集中快速且准确地筛选出异常样本,克服了现有技术中采用人工核查的方式对训练集进行样本核查,导致的核查效率低且容易发生漏检和错检的缺陷。
进一步,在本发明实施例的一个实施方式中,还可以在通过上述步骤S101-步骤S104获取到优化的训练集之后,利用这个优化的训练集生成对抗训练集,然后同时使用这个优化的训练集以及生成的对抗训练集对预设的生成对抗网络模型进行训练,以提高生成对抗网络模型的模型能力。参阅附图3,在本实施方式中可以按照以下获取S201-步骤S204获取对抗训练集。
步骤S201:利用优化后的训练集对预设的数据处理模型进行训练,获取第三数据处理模型。
需要说明的是,在本实施例中可以利用机器学习技术领域中常规的模型训练方法对预设的数据处理模型进行训练,为了描述简洁,在此不再对上述模型训练方法进行赘述。
步骤S202:利用测试集对第三数据处理模型以及多个第四数据处理模型分别进行测试,获取第三数据处理模型分别与每个第四数据处理模型在测试集上的第二模型损失差值。
不同的第四数据处理模型可以被配置成根据优化后的训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被扰动样本。一个例子:如果初始训练集包括样本1、样本2和样本3,那么同时利用样本1-3对预设的数据处理模型进行训练可以得到上述步骤S201中所述的第三数据处理模型。如果利用对样本1增加扰动后形成的子训练集1、对训练样本2增加扰动后形成的子训练集2、对训练样本3增加扰动后形成的子训练集3,分别对预设的数据处理模型进行训练可以得到如下表2所示的第四数据处理模型。
表2
Figure BDA0002819692220000221
第二模型损失差值指的是,第四数据处理模型在测试集上的模型损失减去第三数据处理模型在测试集上的模型损失后,得到的差值。
第三数据处理模型在测试集上的模型损失指的是,在利用测试集内每个测试样本对第三数据处理模型进行测试,得到的每个测试样本各自对应的模型损失以后,将这些模型损失求平均得到的平均损失。需要说明的是,在本实施例中在获取到每个测试样本各自对应的模型损失以后,也可以采用机器学习技术领域中其他常规的获取模型在测试集上的模型损失的方法,对这些模型损失进行计算,得到第三数据处理模型在测试集上的模型损失。
第四数据处理模型在测试集上的模型损失指的是,在利用测试集内每个测试样本对第四数据处理模型进行测试,得到的每个测试样本各自对应的模型损失以后,将这些模型损失求平均得到的平均损失。需要说明的是,在本实施例中在获取到每个测试样本各自对应的模型损失以后,也可以采用机器学习技术领域中其他常规的获取模型在测试集上的模型损失的方法,对这些模型损失进行计算,得到第四数据处理模型在测试集上的模型损失。
此外,需要说明的是,在本实施例中可以采用与获取第三数据处理模型相同的模型训练方法对预设的数据处理模型进行训练,得到每个第四数据处理模型。
参阅附图4,在本实施例中可以按照以下步骤S2021-S2024所示的方法获取第二模型损失差值。
步骤S2021:获取利用优化后的训练集对预设的数据处理模型进行训练后,得到的第三数据处理模型的多个备选数据处理模型。
第三数据处理模型的多个备选数据处理模型指的是,在对预设的数据处理模型进行训练时通过调整模型参数和/或模型结构,得到的多个均能够满足预设的模型训练要求(例如:数据分类的准确率大于等于预设的准确率阈值)的模型。
步骤S2022:利用测试集对步骤S2021得到的多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取最终的第三数据处理模型的第三模型参数。
需要说明的是,在本实施例中可以利用机器学习技术领域中常规的模型测试方法对每个第三数据处理模型进行测试,以从多个第三数据处理模型获取中最优的第三数据处理模型。
步骤S2023:根据第三模型参数,拟合利用测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取最终的第四数据处理模型的第四模型参数。
当前被扰动样本对应的多个备选数据处理模型是利用当前被扰动样本对应的子训练集对预设的数据处理模型进行训练得到的,这些备选数据处理模型指的是在对预设的数据处理模型进行训练时通过调整模型参数和/或模型结构,得到的多个均能够满足预设的模型训练要求(例如:数据分类的准确率大于等于预设的准确率阈值)的模型。
具体而言,在本实施例中可以根据第三模型参数并且按照下式(7)所示的方法,拟合得到第四模型参数:
Figure BDA0002819692220000231
公式(7)中各参数含义如下:
Figure BDA0002819692220000232
表示拟合得到的最终的第四数据处理模型的第四模型参数,zδ表示对当前被扰动样本z增加扰动量δ后形成的新的样本且zδ=(x+δ,y),x表示样本zδ中的图像样本,y表示图像样本x的标签;
Figure BDA0002819692220000233
表示第三模型参数;zi表示训练集内第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示图像样本xi的标签,i=1,...,n;ε2表示预设的当前被扰动样本z的样本权重且
Figure BDA0002819692220000234
步骤S2024:采用稳健统计方法,对第四模型参数与最终的第四数据处理模型在测试集上的模型损失进行影响分析,以获取当前被扰动样本对应的第二模型损失差值。
采用稳健统计方法对第四模型参数与最终的第四数据处理模型在测试集上的模型损失进行影响分析,能够直接获取到“第三数据处理模型与最终的第四数据处理模型在测试集上的模型损失之间的第二模型损失差值”,而无需在利用未对训练样本增加扰动的“优化的训练集”对预设的数据处理模型训练得到的第三数据处理模型后,再利用对训练样本增加扰动的“子训练集”对预设的数据处理模型训练得到的第四数据处理模型,然后分别获取第三数据处理模型与第四数据处理模型在测试集上的模型损失,最后对这两个模型损失进行差值计算来得到第二模型损失差值,即省去了对第四数据处理模型的训练过程,直接通过影响分析就可以得“第三数据处理模型与最终的第四数据处理模型在测试集上的模型损失之间的第二模型损失差值”,从而极大地提高了第二模型损失差值的获取效率,有利于对抗训练集的快速生成。
在本发明实施例的一个实施方式中,基于稳健统计方法,可以按照以下获取来获取第二模型损失差值:
基于稳健统计方法中的影响函数理论,构建下式(8)所示的第四模型参数与最终的第四模型参数在测试集上的模型损失的影响函数,根据影响函数计算第二模型损失差值:
Figure BDA0002819692220000241
公式(8)中各参数含义如下:
Γpert,loss(z,ztest)表示第二模型损失差值,ztest表示测试集,L表示对预设的数据处理模型进行训练与测试时使用的损失函数,
Figure BDA0002819692220000242
表示根据损失函数L计算出的损失值对模型参数θ求梯度,T表示对经
Figure BDA0002819692220000243
计算出的梯度向量进行转置;
Figure BDA0002819692220000244
表示第四模型参数;z表示当前被扰动样本,
Figure BDA0002819692220000245
表示最终的第四数据处理模型的经验风险的Hessian矩阵且
Figure BDA0002819692220000246
Figure BDA0002819692220000247
表示在样本z被增加扰动量δ之前以及之后根据损失函数L计算出的损失值的差值在图像样本x处对应的一阶泰勒展开式。
经验风险指的是数据处理模型对训练集内每个训练样本的模型损失的平均值,其能够度量数据处理模型的训练效果。如果经验风险越小,则表明数据处理模型的训练效果越好;反之,数据处理模型的训练效果越差。需要说明的是,经验风险是机器学习技术领域中的常规技术,为了描述简洁,在此不再赘述。
一阶泰勒展开式指的是泰勒级数(Taylor series)公式的一阶展开公式。需要说明的是,泰勒级数公式是数学技术领域中的常规技术,为了描述简洁,在此不再赘述。
需要说明的是,在本实施例中可以采用数学技术领域中常规的梯度计算方法来根据损失函数L计算出的损失值对模型参数θ求梯度,为了描述简洁,在此不再对梯度计算方法的具体工作进行赘述。
下面对上述公式(8)所示的第四模型参数与最终的第四模型参数在测试集上的模型损失的影响函数的构建过程进行简单说明。
首先,在通过步骤S2022拟合得到被扰动样本z增加扰动量δ后的第四模型参数
Figure BDA0002819692220000251
之后,基于稳健统计方法中的影响函数理论,构建下式(9)所示的扰动量δ与模型参数
Figure BDA0002819692220000252
的影响函数:
Figure BDA0002819692220000253
对公式(9)展开可以得到下式(10):
Figure BDA0002819692220000254
如果训练集内每个样本中的图像样本x是连续的且ε2非常小,那么公式(10)对任意的扰动量δ都是成立的。当扰动量δ很小时,可以用一阶梯度来近似
Figure BDA0002819692220000255
因此扰动量δ与模型参数
Figure BDA0002819692220000256
的影响函数可以近似表达为下式(11)所示的解析形式:
Figure BDA0002819692220000257
如果将被扰动样本z替换成增加扰动量δ以后的样本zδ,那么可以近似得到模型参数的变化量如下式(12)所示:
Figure BDA0002819692220000258
扰动量δ的大小在测试集上引起的模型损失的变化量,可以利用下式(13)计算得到:
Figure BDA0002819692220000261
通过求解公式(13)可以得到公式(8)所示的第四模型参数
Figure BDA0002819692220000262
与模型损失的影响函数。
当被扰动样本增加扰动量δ后,利用测试集对数据处理模型进行测试时,数据处理模型的模型损失会增大Γpert,loss(z,ztest)Tδ。因此,根据模型损失“Γpert,loss(z,ztest)Tδ”的变化趋势,可以调整扰动量δ的值,以获取最大的模型损失“Γpert,loss(z,ztest)Tδ”。进一步,在本实施例中,可以通过衡量Γpert,loss(z,ztest)的大小,来分析第三数据处理模型对训练集扰动的抗击能力。如果Γpert,loss(z,ztest)越大,则表明对训练集扰动的抗击能力越弱,训练集增加扰动后会对第三数据处理模型的模型损失产生较大影响;如果Γpert,loss(z,ztest)越小,则表明对训练集扰动的抗击能力越强,训练集增加扰动后不会对第三数据处理模型的模型损失产生较大影响。
在本实施例中通过第四模型参数
Figure BDA0002819692220000263
与模型损失的影响函数,可以对模型参数的变化对第四数据处理模型的模型损失产生的影响进行量化分析,根据影响函数能够直接计算出第四模型参数
Figure BDA0002819692220000264
对第四数据处理模型的模型损失的影响值(第二模型损失差值),从而极大地提高了获取第二模型损失差值的效率,有利于对抗训练集的快速生成。
步骤S203:根据每个第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个第四数据处理模型各自对应的最大的第二模型损失差值。
一个例子:如果提高某个被扰动样本的扰动量,相应的第二模型损失差值先增大再减小,那么就先不断增加这个被扰动样本的扰动量,直至第二模型损失差值由增加变为减小,则停止增加扰动。
步骤S204:获取最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据扰动量对被扰动样本进行扰动形成新的样本,以根据所述新的样本构建对抗训练集,即将增加扰动后的样本作为对抗样本,将所有的对抗训练样本组成抗训练集。
通过上述步骤S201-步骤S204,本发明实施例能够快速且准确地生成大批量的对抗样本,进而提高对抗生成网络模型的训练效率,并且使训练好的对抗生成网络模型具备较高的模型性能。
需要指出的是,尽管上述实施例中将各个步骤按照特定的先后顺序进行了描述,但是本领域技术人员可以理解,为了实现本发明的效果,不同的步骤之间并非必须按照这样的顺序执行,其可以同时(并行)执行或以其他顺序执行,这些变化都在本发明的保护范围之内。
进一步,本发明还提供了一种模型训练装置。
参阅附图5,图5是根据本发明的一个实施例的模型训练装置的主要结构框图。如图5所示,本发明实施例中的模型训练装置主要包括第一数据处理模型获取模块31、第一损失差值获取模块32、训练集优化模块33和模型训练模块34。在一些实施例中,第一数据处理模型获取模块31、第一损失差值获取模块32、训练集优化模块33和模型训练模块34中的一个或多个可以合并在一起成为一个模块。在一些实施例中,第一数据处理模型获取模块31可以被配置成利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型。第一损失差值获取模块32可以被配置成利用测试集对第一数据处理模型以及多个第二数据处理模型分别进行测试,获取第一数据处理模型分别与每个第二数据处理模型在测试集上的第一模型损失差值。训练集优化模块33可以被配置成根据第一模型损失差值获取初始训练集内的异常样本,并且根据异常样本对初始训练集进行样本调整,获取优化的训练集。模型训练模块34可以被配置成利用优化的训练集对第一数据处理模型进行训练,以获取最终的数据处理模型。其中,不同的第二数据处理模型可以被配置成根据初始训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被删除样本。一个实施方式中,具体实现功能的描述可以参见步骤S101-步骤S104所述。
在一个实施方式中,第一损失差值获取模块32可以包括第一备选模型获取单元、第一参数获取单元、第二参数获取单元和第一损失差值获取单元。在本实施方式中,第一备选模型获取单元可以被配置成获取利用初始训练集对预设的数据处理模型进行训练后,得到的第一数据处理模型的多个备选数据处理模型。第一参数获取单元可以被配置成利用测试集对多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取最终的第一数据处理模型的第一模型参数。第二参数获取单元可以被配置成根据第一模型参数,拟合利用测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取最终的第二数据处理模型的第二模型参数,其中,多个备选数据处理模型是利用当前被删除样本对应的子训练集对预设的数据处理模型进行训练得到的。第一损失差值获取单元可以被配置成采用稳健统计方法,对第二模型参数与最终的第二数据处理模型在测试集上的模型损失进行影响分析,以获取当前被删除样本对应的第一模型损失差值。一个实施方式中,具体实现功能的描述可以参见步骤S102所述。
在一个实施方式中,第二参数获取单元可以被进一步配置成根据第一模型参数并且按照公式(1)所示的方法,拟合第二模型参数。一个实施方式中,具体实现功能的描述可以参见步骤S102所述。
在一个实施方式中,第一损失差值获取单元可以被进一步配置成基于稳健统计方法中的影响函数理论,构建公式(2)所示的第二模型参数与最终的第二数据处理模型在测试集上的模型损失的影响函数,根据影响函数计算第一模型损失差值。一个实施方式中,具体实现功能的描述可以参见步骤S102所述。
在一个实施方式中,训练集优化模块33可以被进一步配置成执行以下操作:对第一模型损失差值进行由负至正的逆向排序;根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值;根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练第二数据处理模型时的被删除样本,将被删除样本作为异常样本。一个实施方式中,具体实现功能的描述可以参见步骤S103所述。
在一个实施方式中,训练集优化模块33可以被进一步配置成执行以下操作:获取异常样本的样本标签;判断样本标签是否正确;若正确,则删除异常样本;若不正确,则修正异常样本的样本标签。一个实施方式中,具体实现功能的描述可以参见步骤S103所述。
参阅附图6,在根据本发明的另一个模型训练装置的实施例中,模型训练装置还可以包括第三数据处理模型获取模块41、第二损失差值获取模块42、第三损失差值获取模块43和对抗训练集获取模块44。在一些实施例中,第三数据处理模型获取模块41、第二损失差值获取模块42、第三损失差值获取模块43和对抗训练集获取模块44中的一个或多个可以合并在一起成为一个模块。在一些实施例中,第三数据处理模型获取模块41可以被配置成利用优化后的训练集对预设的数据处理模型进行训练,获取第三数据处理模型。第二损失差值获取模块42可以被配置成利用测试集对所述第三数据处理模型以及多个第四数据处理模型分别进行测试,获取第三数据处理模型分别与每个第四数据处理模型在测试集上的第二模型损失差值;其中,不同的第四数据处理模型可以被配置成根据优化后的训练集下不同的子训练集训练得到,不同的子训练集之间相差一个或多个不同的被扰动样本。第三损失差值获取模块43可以被配置成根据每个第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个第四数据处理模型各自对应的最大的第二模型损失差值。对抗训练集获取模块44可以被配置成获取最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据扰动量对被扰动样本进行扰动形成新的样本,以根据所述新的样本构建所述对抗训练集。一个实施方式中,具体实现功能的描述可以参见步骤S201-步骤S204所述。
在一个实施方式中,第二损失差值获取模块42可以包括第二备选模型获取单元第三参数获取单元、第四参数获取单元和第二损失差值获取单元。在本实施方式中,第二备选模型获取单元可以被配置成获取利用优化后的训练集对预设的数据处理模型进行训练后,得到的第三数据处理模型的多个备选数据处理模型。第三参数获取单元可以被配置成利用测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取最终的第三数据处理模型的第三模型参数;第四参数获取单元可以被配置成根据第三模型参数,拟合利用测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取最终的第四数据处理模型的第四模型参数,其中,多个备选数据处理模型是利用当前被扰动样本对应的子训练集对预设的数据处理模型进行训练得到的;第二损失差值获取单元可以被配置成采用稳健统计方法,对第四模型参数与最终的第四数据处理模型在测试集上的模型损失进行影响分析,以获取当前被扰动样本对应的第二模型损失差值。一个实施方式中,具体实现功能的描述可以参见步骤S202所述。
在一个实施方式中,第四参数获取单元可以被进一步配置成根据第三模型参数并且按照公式(7)所示的方法,拟合第四模型参数。一个实施方式中,具体实现功能的描述可以参见步骤S202所述。
在一个实施方式中,第二损失差值获取单元可以被进一步配置成基于稳健统计方法中的影响函数理论,构建公式(8)所示的第四模型参数与最终的第四模型参数在测试集上的模型损失的影响函数,根据影响函数计算第二模型损失差值。一个实施方式中,具体实现功能的描述可以参见步骤S202所述。
上述模型训练装置以用于执行图1-4所示的模型训练方法实施例,两者的技术原理、所解决的技术问题及产生的技术效果相似,本技术领域技术人员可以清楚地了解到,为了描述的方便和简洁,模型训练装置的具体工作过程及有关说明,可以参考模型训练方法的实施例所描述的内容,此处不再赘述。
本领域技术人员能够理解的是,本发明实现上述一实施例的方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器、随机存取存储器、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
进一步,本发明还提供了一种计算机可读存储介质。在根据本发明的一个计算机可读存储介质实施例中,计算机可读存储介质可以被配置成存储执行上述方法实施例的模型训练方法的程序,该程序可以由处理器加载并运行以实现上述模型训练方法。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该计算机可读存储介质可以是包括各种电子设备形成的存储装置设备,可选的,本发明实施例中计算机可读存储介质是非暂时性的计算机可读存储介质。
进一步,本发明还提供了一种控制装置。在根据本发明的一个控制装置实施例中,控制装置包括处理器和存储装置,存储装置可以被配置成存储执行上述方法实施例的模型训练方法的程序,处理器可以被配置成用于执行存储装置中的程序,该程序包括但不限于执行上述方法实施例的模型训练方法的程序。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该控制装置可以是包括各种电子设备形成的控制装置设备。
进一步,应该理解的是,由于各个模块的设定仅仅是为了说明本发明的***的功能单元,这些模块对应的物理器件可以是处理器本身,或者处理器中软件的一部分,硬件的一部分,或者软件和硬件结合的一部分。因此,图中的各个模块的数量仅仅是示意性的。
本领域技术人员能够理解的是,可以对***中的各个模块进行适应性地拆分或合并。对具体模块的这种拆分或合并并不会导致技术方案偏离本发明的原理,因此,拆分或合并之后的技术方案都将落入本发明的保护范围内。
至此,已经结合附图所示的一个实施方式描述了本发明的技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下,本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更改或替换之后的技术方案都将落入本发明的保护范围之内。

Claims (22)

1.一种模型训练方法,其特征在于,所述方法包括:
利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型;
利用测试集对所述第一数据处理模型以及多个第二数据处理模型分别进行测试,获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值;
根据所述第一模型损失差值获取所述初始训练集内的异常样本,并且根据所述异常样本对所述初始训练集进行样本调整,获取优化的训练集;
利用所述优化的训练集对所述第一数据处理模型进行训练,以获取最终的数据处理模型;
其中,不同的第二数据处理模型被配置成根据所述初始训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被删除样本。
2.根据权利要求1所述的模型训练方法,其特征在于,“获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值”的步骤具体包括:
获取利用所述初始训练集对所述预设的数据处理模型进行训练后,得到的所述第一数据处理模型的多个备选数据处理模型;
利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取所述最终的第一数据处理模型的第一模型参数;
根据所述第一模型参数,拟合利用所述测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取所述最终的第二数据处理模型的第二模型参数,其中,所述多个备选数据处理模型是利用所述当前被删除样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
采用稳健统计方法,对所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被删除样本对应的第一模型损失差值。
3.根据权利要求2所述的模型训练方法,其特征在于,所述方法包括根据所述第一模型参数并且按照下式所示的方法,拟合得到所述第二模型参数:
Figure FDA0002819692210000021
其中,所述
Figure FDA0002819692210000022
表示拟合得到的所述最终的第二数据处理模型的第二模型参数,所述zdel表示所述当前被删除样本;所述
Figure FDA0002819692210000023
表示所述第一模型参数;所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数;所述zi表示所述训练集内的第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本的标签,i=1,...,n;所述ε1表示预设的所述当前被删除样本zdel的样本权重且
Figure FDA0002819692210000024
4.根据权利要求2所述的模型训练方法,其特征在于,“获取所述当前被删除样本对应的第一模型损失差值”的步骤具体包括:
基于稳健统计方法中的影响函数理论,构建下式所示的所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第一模型损失差值:
Figure FDA0002819692210000025
其中,所述Γup,loss(zdel,ztest)表示所述第一模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure FDA0002819692210000026
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure FDA0002819692210000027
计算出的梯度向量进行转置;所述
Figure FDA0002819692210000028
表示所述第二模型参数;所述zdel表示所述当前被删除样本,所述
Figure FDA0002819692210000029
表示所述最终的第二数据处理模型的经验风险的Hessian矩阵且
Figure FDA00028196922100000210
5.根据权利要求1所述的模型训练方法,其特征在于,“根据所述第一模型损失差值获取所述初始训练集内的异常样本”的步骤具体包括:
对所述第一模型损失差值进行由负至正的逆向排序;
根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值;
根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练所述第二数据处理模型时的被删除样本,将所述被删除样本作为异常样本。
6.根据权利要求1所述的模型训练方法,其特征在于,“根据所述异常样本对所述初始训练集进行样本调整”的步骤具体包括:
获取所述异常样本的样本标签;
判断所述样本标签是否正确;
若正确,则删除所述异常样本;
若不正确,则修正所述异常样本的样本标签。
7.根据权利要求1至6中任一项所述的模型训练方法,其特征在于,所述方法还包括通过下列方式获取对抗训练集,以便利用所述优化后的训练集与所述对抗训练集对预设的生成对抗网络模型进行训练:
利用所述优化后的训练集对所述预设的数据处理模型进行训练,获取第三数据处理模型;
利用所述测试集对所述第三数据处理模型以及多个第四数据处理模型分别进行测试,获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值;其中,不同的第四数据处理模型被配置成根据所述优化后的训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被扰动样本;
根据每个所述第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个所述第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个所述第四数据处理模型各自对应的最大的第二模型损失差值;
获取所述最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据所述扰动量对所述被扰动样本进行扰动形成新的样本,以根据所述新的样本构建所述对抗训练集。
8.根据权利要求7所述的模型训练方法,其特征在于,“获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值”的步骤具体包括:
获取利用所述优化后的训练集对所述预设的数据处理模型进行训练后,得到的所述第三数据处理模型的多个备选数据处理模型;
利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取所述最终的第三数据处理模型的第三模型参数;
根据所述第三模型参数,拟合利用所述测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取所述最终的第四数据处理模型的第四模型参数,其中,所述多个备选数据处理模型是利用所述当前被扰动样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
采用稳健统计方法,对所述第四模型参数与所述最终的第四数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被扰动样本对应的第二模型损失差值。
9.根据权利要求8所述的模型训练方法,其特征在于,所述方法包括根据所述第三模型参数并且按照下式所示的方法,拟合得到所述第四模型参数:
Figure FDA0002819692210000041
其中,所述
Figure FDA0002819692210000042
表示拟合得到的所述最终的第四数据处理模型的第四模型参数,所述zδ表示对当前被扰动样本z增加扰动量δ后形成的新的样本且zδ=(x+δ,y),x表示样本zδ中的图像样本,y表示所述图像样本x的标签;所述
Figure FDA0002819692210000043
表示所述第三模型参数;所述zi表示所述训练集内第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本xi的标签,i=1,...,n;所述ε2表示预设的所述当前被扰动样本z的样本权重且
Figure FDA0002819692210000044
10.根据权利要求8所述的模型训练方法,其特征在于,“获取所述当前被扰动样本对应的第二模型损失差值”的步骤具体包括:
基于稳健统计方法中的影响函数理论,构建下式所示的所述第四模型参数与所述最终的第四模型参数在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第二模型损失差值:
Figure FDA0002819692210000051
其中,所述Γpert,loss(z,ztest)表示所述第二模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure FDA0002819692210000052
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure FDA0002819692210000053
计算出的梯度向量进行转置;所述
Figure FDA0002819692210000054
表示所述第四模型参数;所述z表示所述当前被扰动样本,所述
Figure FDA0002819692210000055
表示所述最终的第四数据处理模型的经验风险的Hessian矩阵且
Figure FDA0002819692210000056
所述
Figure FDA0002819692210000057
表示在样本z被增加扰动量δ之前以及之后根据所述损失函数L计算出的损失值的差值在图像样本x处对应的一阶泰勒展开式。
11.一种模型训练装置,其特征在于,所述装置包括:
第一数据处理模型获取模块,其被配置成利用初始训练集对预设的数据处理模型进行训练,获取第一数据处理模型;
第一损失差值获取模块,其被配置成利用测试集对所述第一数据处理模型以及多个第二数据处理模型分别进行测试,获取所述第一数据处理模型分别与每个所述第二数据处理模型在所述测试集上的第一模型损失差值;
训练集优化模块,其被配置成根据所述第一模型损失差值获取所述初始训练集内的异常样本,并且根据所述异常样本对所述初始训练集进行样本调整,获取优化的训练集;
模型训练模块,其被配置成利用所述优化的训练集对所述第一数据处理模型进行训练,以获取最终的数据处理模型;
其中,不同的第二数据处理模型被配置成根据所述初始训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被删除样本。
12.根据权利要求10所述的模型训练装置,其特征在于,所述第一损失差值获取模块包括第一备选模型获取单元、第一参数获取单元、第二参数获取单元和第一损失差值获取单元;
所述第一备选模型获取单元被配置成获取利用所述初始训练集对所述预设的数据处理模型进行训练后,得到的所述第一数据处理模型的多个备选数据处理模型;
所述第一参数获取单元被配置成利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第一数据处理模型并且获取所述最终的第一数据处理模型的第一模型参数;
所述第二参数获取单元被配置成根据所述第一模型参数,拟合利用所述测试集对当前被删除样本对应的第二数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第二数据处理模型并且获取所述最终的第二数据处理模型的第二模型参数,其中,所述多个备选数据处理模型是利用所述当前被删除样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
所述第一损失差值获取单元被配置成采用稳健统计方法,对所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被删除样本对应的第一模型损失差值。
13.根据权利要求12所述的模型训练装置,其特征在于,所述第二参数获取单元被进一步配置成根据所述第一模型参数并且按照下式所示的方法,拟合得到所述第二模型参数:
Figure FDA0002819692210000061
其中,所述
Figure FDA0002819692210000062
表示拟合得到的所述最终的第二数据处理模型的第二模型参数,所述zdel表示所述当前被删除样本;所述
Figure FDA0002819692210000063
表示所述第一模型参数;所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数;所述zi表示所述训练集内的第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本的标签,i=1,...,n;所述ε1表示预设的所述当前被删除样本zdel的样本权重且
Figure FDA0002819692210000064
14.根据权利要求12所述的模型训练装置,其特征在于,所述第一损失差值获取单元被进一步配置成基于稳健统计方法中的影响函数理论,构建下式所示的所述第二模型参数与所述最终的第二数据处理模型在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第一模型损失差值:
Figure FDA0002819692210000071
其中,所述Γup,loss(zdel,ztest)表示所述第一模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure FDA0002819692210000072
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure FDA0002819692210000073
计算出的梯度向量进行转置;所述
Figure FDA0002819692210000074
表示所述第二模型参数;所述zdel表示所述当前被删除样本,所述
Figure FDA0002819692210000075
表示所述最终的第二数据处理模型的经验风险的Hessian矩阵且
Figure FDA0002819692210000076
15.根据权利要求11所述的模型训练装置,其特征在于,所述训练集优化模块被进一步配置成执行以下操作:
对所述第一模型损失差值进行由负至正的逆向排序;
根据逆向排序的结果,选取排序顺序小于等于预设顺序值的第一模型损失差值;
根据选取到的第一模型损失差值对应的第二数据处理模型,获取在训练所述第二数据处理模型时的被删除样本,将所述被删除样本作为异常样本。
16.根据权利要求11所述的模型训练装置,其特征在于,所述训练集优化模块被进一步配置成执行以下操作:
获取所述异常样本的样本标签;
判断所述样本标签是否正确;
若正确,则删除所述异常样本;
若不正确,则修正所述异常样本的样本标签。
17.根据权利要求11至16中任一项所述的模型训练装置,其特征在于,所述装置还包括:
第三数据处理模型获取模块,其被配置成利用所述优化后的训练集对所述预设的数据处理模型进行训练,获取第三数据处理模型;
第二损失差值获取模块,其被配置成利用所述测试集对所述第三数据处理模型以及多个第四数据处理模型分别进行测试,获取所述第三数据处理模型分别与每个所述第四数据处理模型在所述测试集上的第二模型损失差值;其中,不同的第四数据处理模型被配置成根据所述优化后的训练集下不同的子训练集训练得到,所述不同的子训练集之间相差一个或多个不同的被扰动样本;
第三损失差值获取模块,其被配置成根据每个所述第四数据处理模型各自对应的第二模型损失差值的变化趋势,调整每个所述第四数据处理模型各自对应的被扰动样本的扰动量,以获取每个所述第四数据处理模型各自对应的最大的第二模型损失差值;
对抗训练集获取模块,其被配置成获取所述最大的第二模型损失差值对应的扰动量与被扰动样本,并且根据所述扰动量对所述被扰动样本进行扰动形成新的样本,以根据所述新的样本构建所述对抗训练集。
18.根据权利要求17所述的模型训练装置,其特征在于,所述第二损失差值获取模块包括第二备选模型获取单元、第三参数获取单元、第四参数获取单元和第二损失差值获取单元;
所述第二备选模型获取单元被配置成获取利用所述优化后的训练集对所述预设的数据处理模型进行训练后,得到的所述第三数据处理模型的多个备选数据处理模型;
所述第三参数获取单元被配置成利用所述测试集对所述多个备选数据处理模型分别进行测试,以获取最优的备选数据处理模型作为最终的第三数据处理模型并且获取所述最终的第三数据处理模型的第三模型参数;
所述第四参数获取单元被配置成根据所述第三模型参数,拟合利用所述测试集对当前被扰动样本对应的第四数据处理模型的多个备选数据处理模型分别进行测试,获取最优的备选数据处理模型作为最终的第四数据处理模型并且获取所述最终的第四数据处理模型的第四模型参数,其中,所述多个备选数据处理模型是利用所述当前被扰动样本对应的子训练集对所述预设的数据处理模型进行训练得到的;
所述第二损失差值获取单元被配置成采用稳健统计方法,对所述第四模型参数与所述最终的第四数据处理模型在所述测试集上的模型损失进行影响分析,以获取所述当前被扰动样本对应的第二模型损失差值。
19.根据权利要求18所述的模型训练装置,其特征在于,所述第四参数获取单元被进一步配置成根据所述第三模型参数并且按照下式所示的方法,拟合得到所述第四模型参数:
Figure FDA0002819692210000091
其中,所述
Figure FDA0002819692210000092
表示拟合得到的所述最终的第四数据处理模型的第四模型参数,所述zδ表示对当前被扰动样本z增加扰动量δ后形成的新的样本且zδ=(x+δ,y),x表示样本zδ中的图像样本,y表示所述图像样本x的标签;所述
Figure FDA0002819692210000093
表示所述第三模型参数;所述zi表示所述训练集内第i个样本且zi=(xi,yi),xi表示样本zi中的图像样本,yi表示所述图像样本xi的标签,i=1,...,n;所述ε2表示预设的所述当前被扰动样本z的样本权重且
Figure FDA0002819692210000094
20.根据权利要求18所述的模型训练装置,其特征在于,所述第二损失差值获取单元被进一步配置成基于稳健统计方法中的影响函数理论,构建下式所示的所述第四模型参数与所述最终的第四模型参数在所述测试集上的模型损失的影响函数,根据所述影响函数计算所述第二模型损失差值:
Figure FDA0002819692210000095
其中,所述Γpert,loss(z,ztest)表示所述第二模型损失差值,所述ztest表示所述测试集,所述L表示对预设的数据处理模型进行训练与测试时使用的损失函数,所述
Figure FDA0002819692210000096
表示根据所述损失函数L计算出的损失值对所述模型参数θ求梯度,所述T表示对经
Figure FDA0002819692210000097
计算出的梯度向量进行转置;所述
Figure FDA0002819692210000098
表示所述第四模型参数;所述z表示所述当前被扰动样本,所述
Figure FDA0002819692210000099
表示所述最终的第四数据处理模型的经验风险的Hessian矩阵且
Figure FDA00028196922100000910
所述
Figure FDA00028196922100000911
表示在样本z被增加扰动量δ之前以及之后根据所述损失函数L计算出的损失值的差值在图像样本x处对应的一阶泰勒展开式。
21.一种控制装置,包括处理器和存储装置,所述存储装置适于存储多条程序代码,其特征在于,所述程序代码适于由所述处理器加载并运行以执行权利要求1至10中任一项所述的模型训练方法。
22.一种计算机可读存储介质,其中存储有多条程序代码,其特征在于,所述程序代码适于由处理器加载并运行以执行权利要求1至10中任一项所述的模型训练方法。
CN202011427624.3A 2020-12-07 2020-12-07 模型训练方法、装置以及计算机可读存储介质 Pending CN112529209A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011427624.3A CN112529209A (zh) 2020-12-07 2020-12-07 模型训练方法、装置以及计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011427624.3A CN112529209A (zh) 2020-12-07 2020-12-07 模型训练方法、装置以及计算机可读存储介质

Publications (1)

Publication Number Publication Date
CN112529209A true CN112529209A (zh) 2021-03-19

Family

ID=74996877

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011427624.3A Pending CN112529209A (zh) 2020-12-07 2020-12-07 模型训练方法、装置以及计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN112529209A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113239022A (zh) * 2021-04-19 2021-08-10 浙江大学 医疗诊断缺失数据补全方法及补全装置、电子设备、介质
CN113505800A (zh) * 2021-06-30 2021-10-15 深圳市慧鲤科技有限公司 图像处理方法及其模型的训练方法和装置、设备、介质

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10007866B2 (en) * 2016-04-28 2018-06-26 Microsoft Technology Licensing, Llc Neural network image classifier
CN108932527A (zh) * 2018-06-06 2018-12-04 上海交通大学 使用交叉训练模型检测对抗样本的方法
CN109606378A (zh) * 2018-11-19 2019-04-12 江苏大学 面向非高斯噪声环境的车辆行驶状态估计方法
US20190325738A1 (en) * 2018-04-18 2019-10-24 Here Global B.V. Lane-level geometry and traffic information
CN110378961A (zh) * 2019-09-11 2019-10-25 图谱未来(南京)人工智能研究院有限公司 模型的优化方法、关键点检测方法、装置及存储介质
CN110532880A (zh) * 2019-07-29 2019-12-03 深圳大学 样本筛选及表情识别方法、神经网络、设备及存储介质
CN110796153A (zh) * 2018-08-01 2020-02-14 阿里巴巴集团控股有限公司 一种训练样本的处理方法、装置
CN110866528A (zh) * 2019-10-28 2020-03-06 腾讯科技(深圳)有限公司 一种模型训练方法、能耗使用效率预测方法、装置和介质
CN110991657A (zh) * 2019-11-22 2020-04-10 深圳市魔数智擎人工智能有限公司 一种基于机器学习的异常样本检测方法

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10007866B2 (en) * 2016-04-28 2018-06-26 Microsoft Technology Licensing, Llc Neural network image classifier
US20190325738A1 (en) * 2018-04-18 2019-10-24 Here Global B.V. Lane-level geometry and traffic information
CN108932527A (zh) * 2018-06-06 2018-12-04 上海交通大学 使用交叉训练模型检测对抗样本的方法
CN110796153A (zh) * 2018-08-01 2020-02-14 阿里巴巴集团控股有限公司 一种训练样本的处理方法、装置
CN109606378A (zh) * 2018-11-19 2019-04-12 江苏大学 面向非高斯噪声环境的车辆行驶状态估计方法
CN110532880A (zh) * 2019-07-29 2019-12-03 深圳大学 样本筛选及表情识别方法、神经网络、设备及存储介质
CN110378961A (zh) * 2019-09-11 2019-10-25 图谱未来(南京)人工智能研究院有限公司 模型的优化方法、关键点检测方法、装置及存储介质
CN110866528A (zh) * 2019-10-28 2020-03-06 腾讯科技(深圳)有限公司 一种模型训练方法、能耗使用效率预测方法、装置和介质
CN110991657A (zh) * 2019-11-22 2020-04-10 深圳市魔数智擎人工智能有限公司 一种基于机器学习的异常样本检测方法

Non-Patent Citations (5)

* Cited by examiner, † Cited by third party
Title
PANGWEI KOH ET AL.: ""Understanding Black-box Predictions via Influence Functions"", 《ARXIV:1703.04730V2》 *
SAMYADEEP BASU ET AL.: ""Influence Functions in Deep Learning Are Fragile"", 《ARXIV:2006.14651V1》 *
朱参世 等: ""一种参数容错辨识法判别和剔除野值方法研究"", 《微计算机信息》 *
王强 等: ""基于生成式-判别式混合模型的可解释性文档分类"", 《模式识别与人工智能》 *
袁兴梅 等: ""基于RSC模型和噪声去除的半监督训练方法"", 《计算机工程与科学》 *

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113239022A (zh) * 2021-04-19 2021-08-10 浙江大学 医疗诊断缺失数据补全方法及补全装置、电子设备、介质
WO2022222026A1 (zh) * 2021-04-19 2022-10-27 浙江大学 医疗诊断缺失数据补全方法及补全装置、电子设备、介质
CN113505800A (zh) * 2021-06-30 2021-10-15 深圳市慧鲤科技有限公司 图像处理方法及其模型的训练方法和装置、设备、介质

Similar Documents

Publication Publication Date Title
Colas et al. How many random seeds? statistical power analysis in deep reinforcement learning experiments
CN110009171B (zh) 用户行为模拟方法、装置、设备及计算机可读存储介质
CN110348615B (zh) 基于蚁群优化支持向量机的电缆线路故障概率预测方法
CN107015875B (zh) 一种电子整机贮存寿命评估方法及装置
CN109115383B (zh) 冷挤压强化孔的疲劳寿命预测方法
CN112529209A (zh) 模型训练方法、装置以及计算机可读存储介质
CN112433896B (zh) 一种服务器磁盘故障预测方法、装置、设备及存储介质
CN110795780A (zh) 一种基于XGBoost算法的斜拉桥有限元修正方法
CN111967535B (zh) 一种储粮管理场景温度传感器故障诊断方法及其诊断装置
CN113837596B (zh) 一种故障确定方法、装置、电子设备及存储介质
CN110706213A (zh) 基于应变响应累积分布函数差的桥梁集群结构损伤判别方法
CN112084505A (zh) 深度学习模型恶意样本检测方法、***、设备及存储介质
CN111079348B (zh) 一种缓变信号检测方法和装置
CN112507605A (zh) 基于AnoGAN的配电网异常检测方法
CN114169460A (zh) 样本筛选方法、装置、计算机设备和存储介质
CN117330987B (zh) 基于时间的电池健康状态评估的方法、***、介质和设备
CN112613191A (zh) 电缆健康状态评估方法、装置、计算机设备和存储介质
CN110956112B (zh) 一种新的高可靠性回转支承寿命评估方法
CN112380763A (zh) 一种基于数据挖掘的堆内构件可靠性分析***及方法
Zuiev et al. Questions of radioelectronic equipment diagnostics programs efficiency analysis
CN114236272A (zh) 一种电子产品的智能检测***
CN116610484B (zh) 一种模型训练方法、故障预测方法、***、设备以及介质
CN113257329A (zh) 一种基于机器学习的存储器故障诊断方法
CN107290603B (zh) 一种产品可靠性评价方法及装置
CN117555812B (zh) 一种云平台自动化测试方法及***

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20210319