CN113222149A - 模型训练方法、装置、设备和存储介质 - Google Patents
模型训练方法、装置、设备和存储介质 Download PDFInfo
- Publication number
- CN113222149A CN113222149A CN202110597904.7A CN202110597904A CN113222149A CN 113222149 A CN113222149 A CN 113222149A CN 202110597904 A CN202110597904 A CN 202110597904A CN 113222149 A CN113222149 A CN 113222149A
- Authority
- CN
- China
- Prior art keywords
- model
- data
- training
- test set
- tested
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 139
- 238000000034 method Methods 0.000 title claims abstract description 67
- 238000003860 storage Methods 0.000 title claims abstract description 16
- 238000012360 testing method Methods 0.000 claims abstract description 137
- 238000011156 evaluation Methods 0.000 claims abstract description 67
- 238000002372 labelling Methods 0.000 claims abstract description 34
- 230000015654 memory Effects 0.000 claims description 20
- 238000013136 deep learning model Methods 0.000 claims description 18
- 230000008569 process Effects 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 7
- 230000006870 function Effects 0.000 claims description 6
- 230000009466 transformation Effects 0.000 claims description 6
- 238000012545 processing Methods 0.000 claims description 3
- 238000000844 transformation Methods 0.000 claims description 3
- 230000000694 effects Effects 0.000 abstract description 6
- 238000005457 optimization Methods 0.000 description 10
- 238000001514 detection method Methods 0.000 description 7
- 238000013135 deep learning Methods 0.000 description 6
- 230000007246 mechanism Effects 0.000 description 6
- 230000003902 lesion Effects 0.000 description 5
- 238000004422 calculation algorithm Methods 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 4
- 238000010586 diagram Methods 0.000 description 4
- 238000010801 machine learning Methods 0.000 description 4
- 238000013473 artificial intelligence Methods 0.000 description 3
- 238000009826 distribution Methods 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 239000005338 frosted glass Substances 0.000 description 2
- 238000003384 imaging method Methods 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 206010056342 Pulmonary mass Diseases 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000007405 data analysis Methods 0.000 description 1
- 238000007418 data mining Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 201000010099 disease Diseases 0.000 description 1
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 1
- 230000009977 dual effect Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000007667 floating Methods 0.000 description 1
- 238000011835 investigation Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000000463 material Substances 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000008707 rearrangement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种模型训练方法、装置、设备和存储介质。模型训练方法包括:基于训练数据集进行模型训练,得到中间模型;基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成待测试数据的模型推理结果;基于动态测试集和固定测试集分别对模型推理结果进行测试评估;若固定测试集的评估结果不满足达标条件,则基于动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对中间模型进行迭代训练,直到固定测试集的评估结果满足达标条件,得到目标模型。实现在小样本标注数据驱动下,进行动态的模型评估,快速进行模型迭代,进而快速产生符合标注的模型的效果。
Description
技术领域
本发明实施例涉及深度学习技术,尤其涉及一种模型训练方法、装置、设备和存储介质。
背景技术
深度学习作为机器学习领域重要的研究方向,它被引入机器学习使其更接近于人工智能这一最初的目标。深度学习建模的过程,一般从原始数据出发,使用给定的多层算法和初始参数及初始权重,进行数据分析,并与标准结果相比对,寻找产生差距的因素,调整相关的参数及权重再进行新一轮的模拟,最终使得计算结果与实际结果总公差达到最小的运算过程。在深度学习的实践应用中,若需要将其结合到日常业务中去,需要解决的一个基本问题,即网络模型的构建。
而人工智能技术绝大多数计算都属于有监督学习计算,深度学习在医疗应用场景中的医学文本NLP、数据挖掘、机器学习、影像分类、病灶分割等技术等中都有重要的应用,在这些场景中,如何基于尽可能少的标注数据,进行动态的模型评估,快速进行模型迭代,从而缩短产生符合标注的模型的时间是一个需要解决的问题。
目前模型评估和迭代的方法主要分为如下三类:1)基于注意力机制或设置忽略区的模型迭代方法,主要是将训练数据输入快速成像模型,通过N个多颗粒度注意力模块根据图像的多尺度信息和注意力机制对所述训练数据进行特征提取,融合注意力模块提取到的特征图,根据梯度更新进行成像的训练。2)基于参数合并的模型优化方法,主要是通过带有优化合并参数的优化深度学习模型对待处理数据进行数据处理,节省深度学习模型中的额外计算开销,减少推理计算时间和响应延迟。3)基于预设范围的模型优化方法,主要是通过获取目标深度学习模型中目标网络层的初始输出值和目标网络层的最大输出分布,计算目标网络层的目标浮点值,在初始输出值超出第一预设范围情况下,基于输出缩放系数对初始输出值进行转换,得到初始输出值得目标输出值,根据目标输出值对目标深度学习模型进行优化。
然而现有的基于注意力机制或设置忽略区的模型迭代方法,主要采用了注意力机制进行模型迭代,注意力机制本身存在一些缺点,如在自然语言处理中使用注意力机制,没法捕捉位置信息,即没法学习序列中的顺序关系。现有的基于参数合并的模型优化方法,该类方法主要在深度学习模型优化过程中,将模型对应的卷积与批归一化参数按照优化合并方式进行合并,这是一种模型参数层面的优化方式,并不能解决在基于少量数据在短时间内快速迭代的问题。现有的基于预设范围的模型优化方法,该类方法主要以人为设置范围作为模型优化的思路,由于对范围设定的依赖性导致在实际的模型迭代和优化中存在许多类间与类内变化,会需要大量的标注数据,不能灵活的利用数据,并且不能动态地做到模型的评估。
发明内容
本发明提供一种模型训练方法、装置、设备和存储介质,以实现在小样本标注数据驱动下,进行动态的模型评估,快速进行模型迭代,进而快速产生符合标注的模型。
第一方面,本发明实施例提供了一种模型训练方法,包括:
获取训练数据集,其中,所述训练数据集包括初始标注数据;
基于所述训练数据集进行模型训练,得到中间模型;
基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集;
基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估;
若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
在本发明的可选实施例中,若所述固定测试集的评估结果不满足达标条件,所述方法还包括:
提取所述动态测试集和所述固定测试集中的困难样本,并将所述困难样本确定为迭代训练的待测试数据,其中,所述困难样本为所述动态测试集和所述固定测试集中模型推理结果与对应标注信息的差值大于预设值的样本。
在本发明的可选实施例中,在基于包括所述困难样本的待测试数据进行训练过程中,损失函数为:
其中,ti是第i个样本xi的标注信息,pi是模型预测xi预测正确的概率,r是一个超参数。
在本发明的可选实施例中,所述进行模型训练包括:
确定未标注数据的数据复杂度,基于所述数据复杂度将各所述待测试数据分配至对应等级的标注对象;
接收各所述标注对象反馈的待测试数据的标注信息。
在本发明的可选实施例中,所述确定训练数据集中各待测试数据的数据复杂度包括:
将所述未标注数据输入至深度学习模型中,得到测试值,并基于如下公式对所述测试值确定数据复杂度:
di:图像的复杂度,Pi:深度学习模型的测试值;m:把未标注数据做m种几何变换;J和k:分别是“m种几何变换”其中的一种变换的序号,所以最大是m。y:模型训练的数据的类别数。
在本发明的可选实施例中,在基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估之后,还包括:
将迭代过程中的各中间模型及其评估结果存储至模型库中。
在本发明的可选实施例中,所述确定目标模型,包括:
在所述模型库中,基于各中间模型的评估结果确定目标模型;
在所述目标模型不满足所述固定测试集的达标条件时,继续执行对所述目标模型的迭代训练。
在本发明的可选实施例中,所述评估结果包括中间模型的精准度和召回率,所述基于各中间模型的评估结果确定目标模型,包括:
基于各中间模型的精准度和召回率确定模型选择参数F1值;
将所述F1值最大的中间模型确定为目标模型。
在本发明的可选实施例中,在所述目标模型不满足所述固定测试集的达标条件时,所述方法还包括:
获取模型调整操作,并执行所述模型调整操作,所述模型调整操作包括模型指标调整、训练数据调整、测试数据调整、训练方法调整中的一项或多项。
第二方面,本发明实施例还提供了一种模型训练装置,该模型训练装置包括:
获取模块,用于获取训练数据集,其中,所述训练数据集包括初始标注数据;
得到模块,用于基于所述训练数据集进行模型训练,得到中间模型;
生成模块,用于基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集;
评估模块,用于基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估;
调整模块,用于若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
第三方面,本发明实施例还提供了一种计算机设备,该计算机设备包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如本发明任一实施例所述的模型训练方法。
第四方面,本发明实施例还提供了一种计算机可读存储介质,计算机可读存储介质该其上存储有计算机程序,该程序被处理器执行时实现如本发明任一实施例所述的模型训练方法。
本发明通过动态测试集对模型进行动态的评估,通过固定的测试集对模型进行常规的评估,在固定测试集的评估结果不满足达标条件时基于动态测试集的评估结果调整预设数据获取类型,这样可以节约不必要的同源数据标注,节约标注与训练成本,解决小样本标注数据驱动下无法快速产生符合标注的模型的问题,实现在小样本标注数据驱动下,进行动态的模型评估,快速进行模型迭代,进而快速产生符合标注的模型的效果。
附图说明
图1为本发明实施例一提供的一种模型训练方法的流程图;
图2是本发明实施例二提供的一种模型训练装置的流程框图;
图3为本发明实施例三提供的一种计算机设备的结构示意图。
具体实施方式
下面结合附图和实施例对本发明作进一步的详细说明。可以理解的是,此处所描述的具体实施例仅仅用于解释本发明,而非对本发明的限定。另外还需要说明的是,为了便于描述,附图中仅示出了与本发明相关的部分而非全部结构。
实施例一
图1为本发明实施例一提供的一种模型训练方法的流程图,本实施例可适用于医学影像标注的情况,训练得到的模型可为医学影像的病灶检测模型,通过对电子病历中的医学影像进行标注,能够检测该医学影像是否对应有某类疾病。该方法可以由模型训练装置来执行,具体包括如下步骤:
S110、获取训练数据集,其中,所述训练数据集包括初始标注数据。
其中,训练数据集指初始用来对模型进行训练的数据集合。例如,该模型是医学影像的病灶检测模型,训练数据集即为医学影像。根据定义好的业务场景、数据来源、标注规则后,需准备模型训练的初始标注数据,当业务场景为医院的病灶检测时,初始标注数据可为已分类好的历史医学影像,即历史医学影像中医学影像对应的病灶类型是已得知的,已在医学影像上标注对应的病灶类型。初始数据量大小需结合业务响应周期以及人力、物力进行评定。当本实施例适用于医学影像标注时,数据来源可为电子病历中的医学影像。
此外,当本实施例适用于信用评分时,业务场景可为银行对申请人的贷款评估,通过信用评分模型对申请人的信用进行评分,当申请人的信用高于预设值的时候,才对申请人发放贷款。此时数据来源为申请人的历史征信情况,例如***逾期记录、花呗逾期记录以及逾期次数等。初始标注集即为历史申请人的历史征信情况对应的信用评分。
S120、基于所述训练数据集进行模型训练,得到中间模型。
其中,在人工智能中,面对大量用户输入的数据,如果要在杂乱无章的内容准确、容易地识别,输出我们期待输出的图像/语音,并不是那么容易的,因此算法就显得尤为重要,算法就是我们所说的模型。训练指为识别高识别率的目标,使用大数据,找出最优配置参数的过程。此处的基于训练数据集进行模型训练,指利用训练数据集里的数据,对模型的参数值进行不断的调整,最后在结果统计取一个各方比较均衡、识别率较高的一组参数值,这组参数值,就是我们训练后得到的结果,即为中间模型。
进行模型训练时,利用技术选型得到的网络结构,如果是检测任务可以选择fast-rcnn系列,yolo系列或者自构建结构;如果是其他任务,均可选对应模型结构进行模型训练的工作。
例如,当业务场景为医院的病灶检测时,该中间模型是医学影像的病灶检测模型。当业务场景为银行对申请人的贷款评估时,该中间模型是信用评分模型。
S130、基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集。
其中,数据有不同的类型,比如当该模型为医学影像的病灶检测模型时,数据类型即为病灶的类型,例如肺结节分类中有毛玻璃结节、实性结节等,毛玻璃结节类型。当然,根据应用场景的不同,数据的类型也会相应的不同,此处只是举例说明,不涉及具体的数据类型限定。
基于预设数据获取类型,从待标注数据集里选择相应类型的数据。通过当前中间模型对待测试数据进行标注,模型推理结果即为当前中间模型对待测试数据进行预测得到的结果。动态测试集指人工对当前批待测试数据进行标注后生成的标注数据集,标注信息即为人工所进行的选择标注。
S140、基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估。
其中,固定测试集不会跟随模型的迭代而改变,固定测试集的模型指标是由最初业务确定的常量。固定测试集是对待标注数据集的全部数据进行标注后得到的标注值,动态测试集的对当前批待测试数据进行标注后得到的标注值。进行测试评估是指通过将模型推理结果与固定测试集和动态测试集中的标注值进行比较,得到模型推理结果中预测为正确和预测为错误的样本,进而可以根据预设正确和预设错误的样本数分别计算出在固定测试集中的精准度和召回率等指标,根据这些指标的具体数值得出当前中间模型的模型优劣程度。通过固定测试集和动态测试集分别对模型推理结果进行测试评估是双重保障,提升模型训练效果。
S150、若所述固定测试集的评估结果不满足达标条件,基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
其中,评估结果指将固定测试集中的数据与模型推理结果相比较所得出的差异情况,例如根据固定测试集中的数据与模型推理结果算出当前中间模型的精准度和召回率。达标条件指模型推理结果与固定测试集的差异情况达到了预设值,例如达标条件指精准度和召回率大于95%,当当前中间模型的模型推理结果与固定测试集相比精准度和召回率大于95%,则说明评估结果满足达标条件。
根据动态测试集的评估结果决定下一轮从待标注数据集中选取的数据主要类型,这样可以节约不必要的同源数据标注,节约标注与训练成本。数据同源是传统机器学习依赖的基本假设,即训练数据和测试数据服从相同分布。
可选的,若所述固定测试集的评估结果不满足达标条件,所述方法还包括:提取所述动态测试集和所述固定测试集中的困难样本,并将所述困难样本确定为迭代训练的待测试数据,其中,所述困难样本为所述动态测试集和所述固定测试集中模型推理结果与对应标注信息的差值大于预设值的样本。
具体的,模型推理结果与对应标注信息具有差别的样本也分为困难样本(HardSample)和容易样本(Easy Sample)。困难样本指模型推理结果与对应标注信息的差值较大的样本,容易样本指模型推理结果与对应标注信息的差值较小的样本。例如,预设值为10%,标注信息[1,0,0],而模型推理结果为[0.3,0.3,0.4]时,两者相差大于10%,则此时该样本是困难样本。而预测出[0.98,0.01,0.01]时,与两者相差小于10%,则此时该样本为容易样本。
可选的,在基于包括所述困难样本的待测试数据进行训练过程中,损失函数为:
其中,ti是第i个样本xi的标注信息,pi是模型预测xi预测正确的概率,r是一个超参数。
标注信息指人工对第i个样本的标注值,由上述可知,容易样本指模型推理结果与对应标注信息的差值较小的样本,因此,pi若大于预设值,说明模型推理结果与对应的标注信息的差值较小,即一个容易样本(Easy Sample),则图片就比较小,再r次方一下,会更小。所以Loss在损失函数上就可以使困难样本(Hard Sample)在loss中贡献更大,从而使得训练效果对困难样本学的更好。此处的优化,能够使得困难样本对loss贡献更大,从而达到对loss改进的效果。
可选的,进行模型训练还包括:确定训练数据集中未标注数据的数据复杂度,基于所述数据复杂度将各所述待测试数据分配至对应等级的标注对象,接收各所述标注对象反馈的待测试数据的标注信息。
其中,在模型训练的初期,需要对深度学习模型进行人工培训,通过主动学习,进行相应任务的数据标定。深度学习模型指根据初始标注数据生成中间模型的过程中所迭代的模型。此处的未标注数据指深度学习模型在进行相应任务的数据标定时未标注的数据。将未标注数据再次进行分配至标注对象进行标注,能够找出深度学习模型漏标的数据。利用持续的微调,使用最新标注的数据加上原先标注的数据中被当前模型预测错误的数据,也就是被“遗忘”的数据。能解决在持续微调中常发生的模型遗忘问题,提高数据的利用率,微调模型需要的计算资源少,收敛速度比从头训练快。
标注对象指对未标注数据进行人工标注的人员对应的电子设备,例如标注人员为技师和专家,标注对象指不同技师和专家的电子设备,等级可以是1,2,3等,专家对应1级,技师对应2级。还可以是技师级,专家级等。通过将数据复杂度高的未标注数据发送给等级高的标注对象进行标注,有利于快速标注。
可选的,所述确定训练数据集中各待测试数据的数据复杂度包括:将所述未标注数据输入至深度学习模型中,得到测试值,并基于如下公式对所述测试值确定数据复杂度:
di:数据的复杂度,Pi:深度学习模型的测试值;m:把未标注数据做m种几何变换;J和k:分别是“m种几何变换”其中的一种变换的序号,所以最大是m;y:模型训练的数据的类别数,即数据有多少种类别。
其中在模型训练的初期,需要对深度学习模型进行人工培训,通过主动学习,进行相应任务的数据标定。深度学习模型指根据初始标注数据生成中间模型的过程中所迭代的模型。通过上述公式,能够方便的计算出数据的复杂度。
可选的,在基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估之后,还包括:将迭代过程中的各中间模型及其评估结果存储至模型库中。
其中,评估结果指将固定测试集中的数据与模型推理结果相比较所得出的差异情况,例如根据固定测试集中的数据与模型推理结果算出当前中间模型的精准度和召回率。
在上述实施例的基础上,所述确定目标模型,包括:在所述模型库中,基于各中间模型的评估结果确定目标模型;在所述目标模型不满足所述固定测试集的达标条件时,继续执行对所述目标模型的迭代训练。
其中,目标模型为评估结果最优的中间模型,当目标模型不满足所述固定测试集的达标条件时,说明该目标模型并不能实行交付,需要再次进行训练。
在上述实施例的基础上,所述评估结果包括中间模型的精准度和召回率,所述基于各中间模型的评估结果确定目标模型,包括:
基于各中间模型的精准度和召回率确定模型选择参数F1值。
将所述F1值最大的中间模型确定为目标模型。
其中,TP(a):实际正类预测为正类的数量;FN(b):实际正类预测为负类的数量;FP(c):实际负类预测为正类的数量;TN(d):实际负类预测为负类的数量;T=True,F=False,表示是否预测正确;P=Positive,N=Negative,表示预测结果是正类还是负类。
召回率(Recall),查全率,正确预测为正的占全部实际为正的比例。召回率的计算公式为:recall=TP/(TP+FN)。
精准度(Precision):预测出来的正样本中,正确的有多少。精准度的公式为:precision=TP/(TP+FP)。
F1值(H-mean值),为了能够评价不同算法的优劣,在Precision和Recall的基础上提出了F1值的概念,来对Precision和Recall进行整体评价。F1值为算数平均数除以几何平均数,且越大越好,将Precision和Recall的上述公式带入会发现,当F1值小时,TruePositive相对增加,而false相对减少,即Precision和Recall都相对增加,即F1对Precision和Recall都进行了加权。F1值的计算公式为:
其中,F1值越大,说明模型越优,将F1值最大的中间模型确定为目标模型,能够使目标模型为模型库中的最优模型。
在上述实施例的基础上,在所述目标模型不满足所述固定测试集的达标条件时,所述方法还包括:
获取模型调整操作,并执行所述模型调整操作,所述模型调整操作包括模型指标调整、训练数据调整、测试数据调整、训练方法调整中的一项或多项。
其中,当训练数据集中的所有待测试数据均测试完成时,由于此时已无数据进行测试,从而模型迭代过程结束,此时倘若目标模型仍不满足固定测试集的达标条件,说明可能是模型本身有问题,例如模型指标选取不当、训练数据选取不当、测试数据选取不当以及训练方法不当等。模型调整操作可通过人工输入,通过获取模型调整操作并执行,能够对模型进行调整,防止持续自动在错误的模型中进行训练效率低下的情况发生。相应的,此时继续执行对所述目标模型的迭代训练具体包括继续执行对进行模型调整操作后的目标模型进行迭代训练。
本实施例的技术方案,通过动态测试集对模型进行动态的评估,通过固定的测试集对模型进行常规的评估,在固定测试集的评估结果不满足达标条件时基于动态测试集的评估结果调整预设数据获取类型,这样可以节约不必要的同源数据标注,节约标注与训练成本,解决小样本标注数据驱动下无法快速产生符合标注的模型的问题,实现在小样本标注数据驱动下,进行动态的模型评估,快速进行模型迭代,进而快速产生符合标注的模型的效果。
实施例二
图2是本发明实施例二提供的一种模型训练装置的流程框图,如图2所示,本发明实施例的模型训练装置具体可以包括如下模块:
获取模块61,用于获取训练数据集,其中,所述训练数据集包括初始标注数据。
得到模块62,用于基于所述训练数据集进行模型训练,得到中间模型。
生成模块63,用于基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集。
评估模块64,用于基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估。
调整模块65,用于若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
在本发明可选实施例中,若所述固定测试集的评估结果不满足达标条件,该模型训练装置还包括:
提取模块,用于提取所述动态测试集和所述固定测试集中的困难样本,并将所述困难样本确定为迭代训练的待测试数据,其中,所述困难样本为所述动态测试集和所述固定测试集中模型推理结果不同于对应标注信息的样本。
在本发明可选实施例中,该模型训练装置还包括:
确定模块,用于确定训练数据集中各待测试数据的数据复杂度,基于所述数据复杂度将各所述待测试数据分配至对应等级的标注对象。
接收模块,用于接收各所述标注对象反馈的待测试数据的标注信息。
在本发明可选实施例中,确定模块,还用于将所述未标注数据输入至深度学习模型中,得到测试值,并基于如下公式对所述测试值确定数据复杂度:
在本发明可选实施例中,该模型训练装置还包括:
存储模块,用于将迭代过程中的各中间模型及其评估结果存储至模型库中。
在本发明可选实施例中,该模型训练装置还包括:
模型确定模块,用于在所述模型库中,基于各中间模型的评估结果确定目标模型。
迭代执行模块,用于在所述目标模型不满足所述固定测试集的达标条件时,继续执行对所述目标模型的迭代训练。
在本发明可选实施例中,所述评估结果包括中间模型的精准度和召回率,所述模型确定模块还包括:
选择子模块,用于基于各中间模型的精准度和召回率确定模型选择参数F1值。
确定子模块,用于将所述F1值最大的中间模型确定为目标模型。
在本发明可选实施例中,该模型训练装置还包括:
模型调整模块,用于获取模型调整操作,并执行所述模型调整操作,所述模型调整操作包括模型指标调整、训练数据调整、测试数据调整、训练方法调整中的一项或多项。
实施例三
图3为本发明实施例三提供的一种计算机设备的结构示意图,如图3所示,
包括存储器71、处理器72及存储在存储器71上并可在处理器72上运行的计算机程序,所述处理器72执行所述程序时实现如上述任一实施例所述的模型训练方法。
该计算机设备包括处理器72、存储器71、输入装置73和输出装置74;计算机设备中处理器72的数量可以是一个或多个,图3中以一个处理器72为例;计算机设备中的处理器72、存储器71、输入装置73和输出装置74可以通过总线或其他方式连接,图3中以通过总线连接为例。
存储器71作为一种计算机可读存储介质,可用于存储软件程序、计算机可执行程序以及模块,如本发明实施例中模型训练方法的对应的程序指令/模块(例如,模型训练装置中的获取模块、得到模块、生成模块、评估模块和调整模块)。
存储器71可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作***、至少一个功能所需的应用程序;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器71可以包括高速随机存取存储器71,还可以包括非易失性存储器71,例如至少一个磁盘存储器71件、闪存器件、或其他非易失性固态存储器71件。在一些实例中,存储器71可进一步包括相对于处理器72远程设置的存储器71,这些远程存储器71可以通过网络连接至设备/终端/服务器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
输入装置73可用于接收输入的数字或者字符信息,以及产生与设备的电力用户设置以及功能控制有关的键信号输入。
输出装置74可包括显示可用于接收输入的数字或字符信息,以及产生与屏等显示设备。
处理器72通过运行存储在存储器71中的软件程序、指令以及模块,从而执行计算机设备的各种功能应用以及数据处理,即实现上述的模型训练方法。
实施例四
本发明实施例四还提供一种包含计算机可执行指令的存储介质,所述计算机可执行指令在由计算机处理器执行时用于执行一种模型训练方法,该方法包括:
获取训练数据集,其中,所述训练数据集包括初始标注数据。
基于所述训练数据集进行模型训练,得到中间模型。
基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集。
基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估。
若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
当然,本发明实施例所提供的一种包含计算机可执行指令的存储介质,其计算机可执行指令不限于如上所述的方法操作,还可以执行本发明任意实施例所提供的模型训练方法中的相关操作.
通过以上关于实施方式的描述,所属领域的技术人员可以清楚地了解到,本发明可借助软件及必需的通用硬件来实现,当然也可以通过硬件实现,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在计算机可读存储介质中,如计算机的软盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(RandomAccess Memory,RAM)、闪存(FLASH)、硬盘或光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
值得注意的是,上述模型训练装置的实施例中,所包括的各个单元和模块只是按照功能逻辑进行划分的,但并不局限于上述的划分,只要能够实现相应的功能即可;另外,各功能单元的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。
注意,上述仅为本发明的较佳实施例及所运用技术原理。本领域技术人员会理解,本发明不限于这里所述的特定实施例,对本领域技术人员来说能够进行各种明显的变化、重新调整和替代而不会脱离本发明的保护范围。因此,虽然通过以上实施例对本发明进行了较为详细的说明,但是本发明不仅仅限于以上实施例,在不脱离本发明构思的情况下,还可以包括更多其他等效实施例,而本发明的范围由所附的权利要求范围决定。
Claims (12)
1.一种模型训练方法,其特征在于,包括:
获取训练数据集,其中,所述训练数据集包括初始标注数据;
基于所述训练数据集进行模型训练,得到中间模型;
基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集;
基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估;
若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
2.根据权利要求1所述的模型训练方法,其特征在于,若所述固定测试集的评估结果不满足达标条件,所述方法还包括:
提取所述动态测试集和所述固定测试集中的困难样本,并将所述困难样本确定为迭代训练的待测试数据,其中,所述困难样本为所述动态测试集和所述固定测试集中模型推理结果与对应标注信息的差值大于预设值的样本。
4.根据权利要求1所述的模型训练方法,其特征在于,所述进行模型训练包括:
确定未标注数据的数据复杂度,基于所述数据复杂度将各所述待测试数据分配至对应等级的标注对象;
接收各所述标注对象反馈的待测试数据的标注信息。
6.根据权利要求1所述的模型训练方法,其特征在于,在基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估之后,还包括:
将迭代过程中的各中间模型及其评估结果存储至模型库中。
7.根据权利要求6所述的模型训练方法,其特征在于,所述确定目标模型,包括:
在所述模型库中,基于各中间模型的评估结果确定目标模型;
在所述目标模型不满足所述固定测试集的达标条件时,继续执行对所述目标模型的迭代训练。
8.根据权利要求7所述的模型训练方法,其特征在于,所述评估结果包括中间模型的精准度和召回率,所述基于各中间模型的评估结果确定目标模型,包括:
基于各中间模型的精准度和召回率确定模型选择参数F1值;
将所述F1值最大的中间模型确定为目标模型。
9.根据权利要求7所述的模型训练方法,其特征在于,在所述目标模型不满足所述固定测试集的达标条件时,所述方法还包括:
获取模型调整操作,并执行所述模型调整操作,所述模型调整操作包括模型指标调整、训练数据调整、测试数据调整、训练方法调整中的一项或多项。
10.一种模型训练装置,其特征在于,包括:
获取模块,用于获取训练数据集,其中,所述训练数据集包括初始标注数据;
得到模块,用于基于所述训练数据集进行模型训练,得到中间模型;
生成模块,用于基于预设数据获取类型,从待标注数据集中获取数据作为待测试数据,并基于当前中间模型生成所述待测试数据的模型推理结果,基于所述待测试数据、所述待测试数据的模型推理结果以及标注信息更新动态测试集;
评估模块,用于基于所述动态测试集和固定测试集分别对所述模型推理结果进行测试评估;
调整模块,用于若所述固定测试集的评估结果不满足达标条件,则基于所述动态测试集的评估结果调整预设数据获取类型,并基于调整后的调整预设数据获取类型对应的待测试数据对所述中间模型进行迭代训练,直到所述固定测试集的评估结果满足所述达标条件,得到目标模型。
11.一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1-9中任一所述的模型训练方法。
12.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1-9中任一所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110597904.7A CN113222149B (zh) | 2021-05-31 | 2021-05-31 | 模型训练方法、装置、设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110597904.7A CN113222149B (zh) | 2021-05-31 | 2021-05-31 | 模型训练方法、装置、设备和存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113222149A true CN113222149A (zh) | 2021-08-06 |
CN113222149B CN113222149B (zh) | 2024-04-26 |
Family
ID=77099322
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110597904.7A Active CN113222149B (zh) | 2021-05-31 | 2021-05-31 | 模型训练方法、装置、设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113222149B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113688989A (zh) * | 2021-08-31 | 2021-11-23 | 中国平安人寿保险股份有限公司 | 深度学习网络加速方法、装置、设备及存储介质 |
CN114067109A (zh) * | 2022-01-13 | 2022-02-18 | 安徽高哲信息技术有限公司 | 谷物检测方法及检测设备、存储介质 |
CN114741269A (zh) * | 2022-04-14 | 2022-07-12 | 网思科技股份有限公司 | 一种推理***业务性能评估的方法 |
WO2023060954A1 (zh) * | 2021-10-14 | 2023-04-20 | 北京百度网讯科技有限公司 | 数据处理与数据质检方法、装置及可读存储介质 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109829375A (zh) * | 2018-12-27 | 2019-05-31 | 深圳云天励飞技术有限公司 | 一种机器学习方法、装置、设备及*** |
WO2020119075A1 (zh) * | 2018-12-10 | 2020-06-18 | 平安科技(深圳)有限公司 | 通用文本信息提取方法、装置、计算机设备和存储介质 |
-
2021
- 2021-05-31 CN CN202110597904.7A patent/CN113222149B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020119075A1 (zh) * | 2018-12-10 | 2020-06-18 | 平安科技(深圳)有限公司 | 通用文本信息提取方法、装置、计算机设备和存储介质 |
CN109829375A (zh) * | 2018-12-27 | 2019-05-31 | 深圳云天励飞技术有限公司 | 一种机器学习方法、装置、设备及*** |
Non-Patent Citations (1)
Title |
---|
翟翔宇;杨风暴;吉琳娜;李书强;吕红亮;: "标准化全连接残差网络空战目标威胁评估", 火力与指挥控制, no. 06, 15 June 2020 (2020-06-15) * |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113688989A (zh) * | 2021-08-31 | 2021-11-23 | 中国平安人寿保险股份有限公司 | 深度学习网络加速方法、装置、设备及存储介质 |
CN113688989B (zh) * | 2021-08-31 | 2024-04-19 | 中国平安人寿保险股份有限公司 | 深度学习网络加速方法、装置、设备及存储介质 |
WO2023060954A1 (zh) * | 2021-10-14 | 2023-04-20 | 北京百度网讯科技有限公司 | 数据处理与数据质检方法、装置及可读存储介质 |
CN114067109A (zh) * | 2022-01-13 | 2022-02-18 | 安徽高哲信息技术有限公司 | 谷物检测方法及检测设备、存储介质 |
CN114067109B (zh) * | 2022-01-13 | 2022-04-22 | 安徽高哲信息技术有限公司 | 谷物检测方法及检测设备、存储介质 |
CN114741269A (zh) * | 2022-04-14 | 2022-07-12 | 网思科技股份有限公司 | 一种推理***业务性能评估的方法 |
CN114741269B (zh) * | 2022-04-14 | 2022-09-23 | 网思科技股份有限公司 | 一种推理***业务性能评估的方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113222149B (zh) | 2024-04-26 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
EP3989119A1 (en) | Detection model training method and apparatus, computer device, and storage medium | |
CN113222149B (zh) | 模型训练方法、装置、设备和存储介质 | |
CN109993102B (zh) | 相似人脸检索方法、装置及存储介质 | |
US11861925B2 (en) | Methods and systems of field detection in a document | |
EP3968179A1 (en) | Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device | |
EP3989104A1 (en) | Facial feature extraction model training method and apparatus, facial feature extraction method and apparatus, device, and storage medium | |
CN112784778B (zh) | 生成模型并识别年龄和性别的方法、装置、设备和介质 | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
CN110288017B (zh) | 基于动态结构优化的高精度级联目标检测方法与装置 | |
CN112926654A (zh) | 预标注模型训练、证件预标注方法、装置、设备及介质 | |
Freytag et al. | Labeling examples that matter: Relevance-based active learning with gaussian processes | |
CN110781970A (zh) | 分类器的生成方法、装置、设备及存储介质 | |
CN111178537A (zh) | 一种特征提取模型训练方法及设备 | |
WO2015146113A1 (ja) | 識別辞書学習システム、識別辞書学習方法および記録媒体 | |
CN112949519A (zh) | 目标检测方法、装置、设备及存储介质 | |
CN113192028B (zh) | 人脸图像的质量评价方法、装置、电子设备及存储介质 | |
CN111161238A (zh) | 图像质量评价方法及装置、电子设备、存储介质 | |
CN114299340A (zh) | 模型训练方法、图像分类方法、***、设备及介质 | |
CN113569018A (zh) | 问答对挖掘方法及装置 | |
CN117371511A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
CN115795355B (zh) | 一种分类模型训练方法、装置及设备 | |
CN115482436B (zh) | 图像筛选模型的训练方法、装置以及图像筛选方法 | |
CN111582404B (zh) | 内容分类方法、装置及可读存储介质 | |
CN115762721A (zh) | 一种基于计算机视觉技术的医疗影像质控方法和*** | |
CN115410250A (zh) | 阵列式人脸美丽预测方法、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |