CN114707638A - 模型训练、对象识别方法及装置、设备、介质和产品 - Google Patents

模型训练、对象识别方法及装置、设备、介质和产品 Download PDF

Info

Publication number
CN114707638A
CN114707638A CN202210281873.9A CN202210281873A CN114707638A CN 114707638 A CN114707638 A CN 114707638A CN 202210281873 A CN202210281873 A CN 202210281873A CN 114707638 A CN114707638 A CN 114707638A
Authority
CN
China
Prior art keywords
model
network model
loss function
function value
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210281873.9A
Other languages
English (en)
Inventor
杨馥魁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202210281873.9A priority Critical patent/CN114707638A/zh
Publication of CN114707638A publication Critical patent/CN114707638A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供了一种模型训练、对象识别方法及装置、设备、介质和产品,涉及人工智能领域,尤其涉及深度学习、计算机视觉和神经网络技术领域。具体实现方案包括:对初始网络模型的模型参数进行变换,得到变换后的模型参数;将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;根据第一模型数据、第二模型数据和第三模型数据,确定第二中间网络模型,第一中间网络模型包括第二中间网络模型;以及在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。

Description

模型训练、对象识别方法及装置、设备、介质和产品
技术领域
本公开涉及人工智能领域,尤其涉及深度学习、计算机视觉和神经网络技术领域,可应用于模型训练、对象识别等场景。
背景技术
深度学习遍及人工智能应用的各个领域,网络模型训练是深度学习的核心技术。但是,在一些场景下,模型训练过程存在样本数量要求高、训练效率低、训练效果不佳的现象。
发明内容
本公开提供了一种模型训练、对象识别方法及装置、设备、介质和产品。
根据本公开的一方面,提供了一种模型训练方法,包括:对初始网络模型的模型参数进行变换,得到变换后的模型参数;将训练样本输入基准网络模型、所述初始网络模型和与所述变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,所述第一中间网络模型包括所述第二中间网络模型;以及在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
根据本公开的一方面,提供了一种对象识别方法,包括:获取待识别的目标数据;以及利用对象识别模型,对所述目标数据进行对象特征提取,得到由所述对象特征指示的对象识别结果,所述对象识别模型采用如上述的模型训练方法生成。
根据本公开的另一方面,提供了一种模型训练装置,包括:第一处理模块,用于对初始网络模型的模型参数进行变换,得到变换后的模型参数;第二处理模块,用于将训练样本输入基准网络模型、所述初始网络模型和与所述变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;第三处理模块,用于根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,所述第一中间网络模型包括所述第二中间网络模型;以及第四处理模块,用于在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
根据本公开的一方面,提供了一种对象识别装置,包括:获取模块,用于获取待识别的目标数据;以及识别模块,用于利用对象识别模型,对所述目标数据进行对象特征提取,得到由所述对象特征指示的对象识别结果,所述对象识别模型采用如上述的模型训练装置生成。
根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器和与所述至少一个处理器通信连接的存储器。其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述的模型训练方法,或者执行上述的对象识别方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行上述的模型训练方法,或者执行上述的对象识别方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现上述的模型训练方法,或者实现上述的对象识别方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1示意性示出了根据本公开一实施例的模型训练方法和装置的***架构;
图2示意性示出了根据本公开一实施例的模型训练方法的流程图;
图3示意性示出了根据本公开另一实施例的模型训练方法的流程图;
图4示意性示出了根据本公开一实施例的模型训练过程的示意图;
图5示意性示出了根据本公开一实施例的对象识别方法的流程图;
图6示意性示出了根据本公开一实施例的模型训练装置的框图;
图7示意性示出了根据本公开一实施例的对象识别装置的框图;
图8示意性示出了根据本公开实施例的用于执行模型训练电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
在此使用的术语仅仅是为了描述具体实施例,而并非意在限制本公开。在此使用的术语“包括”、“包含”等表明了所述特征、步骤、操作和/或部件的存在,但是并不排除存在或添加一个或多个其他特征、步骤、操作或部件。
在此使用的所有术语(包括技术和科学术语)具有本领域技术人员通常所理解的含义,除非另外定义。应注意,这里使用的术语应解释为具有与本说明书的上下文相一致的含义,而不应以理想化或过于刻板的方式来解释。
在使用类似于“A、B和C等中至少一个”这样的表述的情况下,一般来说应该按照本领域技术人员通常理解该表述的含义来予以解释(例如,“具有A、B和C中至少一个的***”应包括但不限于单独具有A、单独具有B、单独具有C、具有A和B、具有A和C、具有B和C、和/或具有A、B、C的***等)。
本公开的实施例提供了一种模型训练方法。本方法包括:对初始网络模型的模型参数进行变换,得到变换后的模型参数;将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;根据第一模型数据、第二模型数据和第三模型数据,确定第二中间网络模型,第一中间网络模型包括第二中间网络模型;以及在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
图1示意性示出了根据本公开一实施例的模型训练方法和装置的***架构。需要注意的是,图1所示仅为可以应用本公开实施例的***架构的示例,以帮助本领域技术人员理解本公开的技术内容,但并不意味着本公开实施例不可以用于其他设备、***、环境或场景。
根据该实施例的***架构100可以包括请求终端101、网络102和服务器103。网络102用于在请求终端101和服务器103之间提供通信链路的介质。网络102可以包括各种连接类型,例如有线、无线通信链路或者光纤电缆等等。服务器103可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或分布式***,还可以是提供云服务、云计算、网络服务、中间件服务等基础云计算服务的云服务器。
请求终端101通过网络102与服务器103进行交互,以接收或发送数据等。请求终端101例如用于向服务器103发起模型训练请求,请求终端101例如还用于向服务器103提供用于模型训练的训练样本和基准网络模型。
服务器103可以是提供各种服务的服务器,例如可以是根据由请求终端101提供的训练样本和基准网络模型进行模型训练的后台处理服务器(仅为示例)。
例如,服务器103响应于从请求终端101接收的模型训练请求,对初始网络模型的模型参数进行变换,得到变换后的模型参数,将由请求终端101提供的训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据。服务器103还用于根据第一模型数据、第二模型数据和第三模型数据,确定第二中间网络模型,第一中间网络模型包括第二中间网络模型,以及在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
需要说明的是,本公开实施例所提供的模型训练方法可以由服务器103执行。相应地,本公开实施例所提供的模型训练装置可以设置于服务器103中。本公开实施例所提供的模型训练方法也可以由不同于服务器103且能够与请求终端101和/或服务器103通信的服务器或服务器集群执行。相应地,本公开实施例所提供的模型训练装置也可以设置于不同于服务器103且能够与请求终端101和/或服务器103通信的服务器或服务器集群中。
应该理解,图1中的请求终端、网络和服务器的数目仅仅是示意性的。根据实现需要,可以具有任意数目的请求终端、网络和服务器。
本公开实施例提供了一种模型训练方法,下面结合图1的***架构,参考图2~图4来描述根据本公开示例性实施方式的模型训练方法。本公开实施例的模型训练方法例如可以由图1所示的服务器103来执行。
图2示意性示出了根据本公开一实施例的模型训练方法的流程图。
如图2所示,本公开实施例的模型训练方法200例如可以包括操作S210~操作S240。
在操作S210,对初始网络模型的模型参数进行变换,得到变换后的模型参数。
在操作S220,将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据。
在操作S230,根据第一模型数据、第二模型数据和第三模型数据,确定第二中间网络模型,第一中间网络模型包括第二中间网络模型。
在操作S240,在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
下面示例说明本实施例的模型训练方法的各操作的示例流程。
示例性地,初始网络模型可以是初始创建后尚未达到准确可用状态的神经网络模型,需要对初始网络模型的模型参数进行更新,以实现训练初始网络模型。可以基于预设变换参数,对初始网络模型的模型参数进行变换,得到变换后的模型参数。
例如,初始网络模型的模型参数可以包括模型网络层级结构中的卷积核参数,卷积核参数例如可以包括卷积核个数、卷积核尺寸、输入通道数、输出通道数、卷积核权重、偏置参数等内容。
基于预设变换参数,对初始网络模型的模型参数进行变换,预设变换参数例如可以包括随机变换参数和随机位移参数。示例性地,随机变换参数的取值例如可以是scale=random(0.5,2),随机位移参数的取值例如可以是offset=random(-1,1)。针对初始网络模型的模型参数m0,基于随机变换参数scale和随机位移参数offset,对模型参数m0进行变换,得到变换后的模型参数m1=scale*m0+offset。
将训练样本分别输入基准网络模型、初始网络模型和与变换后的模型参数对应的至少一个第一中间网络模型。基准网络模型例如可以是经训练的教师网络模型,初始网络模型例如可以是待训练的学生网络模型。教师网络模型和学生网络模型可以包括相同的网络层级结构,学生网络模型的模型参数规模小于教师网络模型的模型参数规模。
可以将训练样本分别输入教师网络模型、初始学生网络模型和与每一个变换后的模型参数对应的中间学生网络模型。教师网络模型可以是基于数据量充足的样本数据训练得到的成熟网络模型,教师网络模型具有较大的模型参数规模和较高的模型准确率。可以将教师网络模型输出的模型数据作为监督信息,指导调整初始学生网络模型的模型参数。
可以根据第一模型数据、第二模型数据和第三模型数据,确定第二中间网络模型,第一中间网络模型包括第二中间网络模型。第一模型数据、第二模型数据和第三模型数据中均分别包括以下至少一项:样本特征、特征关联关系、样本识别结果、分类结果等内容。
一种示例方式,可以根据第一模型数据和第二模型数据,确定针对初始网络模型的损失函数值。根据第一模型数据和由每个第一中间网络模型输出的第三模型数据,确定针对每个第一中间网络模型的损失函数值。根据针对初始网络模型的损失函数值和针对每个第一中间网络模型的损失函数值,确定第二中间网络模型。第一中间网络模型包括第二中间网络模型。
确定是否达到预设迭代终止条件,在未达到迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。示例性地,可以基于预设网络层级结构顺序,对各网络层级结构中的模型参数依次进行变换,以实现迭代优化初始网络模型的模型参数。
在达到迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为经训练的目标网络模型。示例性地,将当前迭代轮次中的第二中间网络模型作为经训练的学生网络模型。通过将教师网络模型输出的第一模型数据作为监督信息,对初始学生网络模型的模型参数进行调整,生成达到准确可用状态的目标网络模型。
在生成目标网络模型之后,可以利用目标网络模型进行数据处理。示例性地,可以获取待处理的源数据,将待处理的源数据输入目标网络模型,以获取与源数据对应的目标数据。
根据分别由基准网络模型、初始网络模型和第一中间网络模型输出的模型数据,迭代优化初始网络模型的模型参数,以实现针对初始网络模型的训练目的。将基准网络模型的模型输出数据作为监督信息,指导调整初始网络模型的模型参数,有利于提高网络模型训练效率,降低网络模型训练的样本数据量要求,有利于快速生成准确可用的目标网络模型。
图3示意性示出了根据本公开另一实施例的模型训练方法的流程图。
如图3所示,操作S230例如可以包括操作S310~S330。
在操作S310,根据第一模型数据和第二模型数据,确定第一损失函数值。
在操作S320,根据第一模型数据和第三模型数据,确定针对每个第一中间网络模型的第二损失函数值。
在操作S330,根据第一损失函数值和第二损失函数值,确定第二中间网络模型,第一模型数据、第二模型数据和第三模型数据中均分别包括以下至少一项:样本特征、分类结果。
下面示例说明本实施例的模型训练方法的各操作的示例流程。
一种示例方式,通过将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一样本特征、第二样本特征和第三样本特征。
根据第一样本特征和第二样本特征,确定针对初始网络模型的第一损失函数值。例如,可以计算第一样本特征和第二样本特征之间的Frobenius(弗罗贝尼乌斯)范数距离,简称F范数距离,得到针对初始网络模型的第一损失函数值。
根据第一样本特征和第三样本特征,确定针对每个第一中间网络模型的第二损失函数值。第一中间网络模型与变换后的模型参数之间具有映射关系,即可以根据第一样本特征和第三样本特征,确定与每个变换后的模型参数对应的第二损失函数值。
根据针对每个第一中间网络模型的第二损失函数值,在最小的第二损失函数值小于第一损失函数值的情况下,将最小的第二损失函数值对应的第一中间网络模型作为第二中间网络模型。
在确定未达到迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
示例性地,迭代终止条件例如可以包括以下条件中的至少之一:完成针对预设数量的模型参数的变换操作、迭代次数达到预设轮次阈值、与第一损失函数值达成收敛的第二损失函数值的个数达到预设数量阈值。迭代终止条件可以根据实际的模型训练目标进行设置,本实施例对此不进行限定。
将由基准网络模型输出的样本特征作为监督信息,指导调整待训练网络模型的模型参数,以达到训练网络模型的目的。可以在使用少量样本数据的情况下,有效抑制待训练网络模型的过拟合现象,能够有效降低模型训练的样本数据量要求,和有效提升模型训练效率。
另一分类方式,将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据。第一模型数据、第二模型数据和第三模型数据中均分别包括类结果,分类结果指示了样本数据属于目标类别的预测概率。在用于模型训练的样本数据为样本图像的情况下,与样本图像关联的分类结果例如可以指示针对各像素的分类概率。
在基于分类结果计算损失函数值时,例如可以根据分类结果,确定样本图像中的边缘像素和非边缘像素。结合由分类结果指示的针对各像素的分类概率,根据与边缘像素和非边缘像素关联的预设权重,计算加权交叉熵损失函数值,作为与对应网络模型关联的损失函数值。
示例性地,通过将训练样本输入基准网络模型、初始网络模型和与变换后的模型参数对应的第一中间网络模型,分别对应得到第一分类结果、第二分类结果和第三分类结果。
根据第一分类结果和第二分类结果,确定针对初始网络模型的第一损失函数值。根据第一分类结果和第三分类结果,确定针对每个第一中间网络模型的第二损失函数值。
根据第一损失函数值和针对每个第一中间网络模型的第二损失函数值,确定第二中间网络模型。一种示例方式,根据针对每个第一中间网络模型的第二损失函数值,确定最小的第二损失函数值。在最小的第二损失函数值小于第一损失函数值的情况下,将最小的第二损失函数值对应的第一中间网络模型确定为第二中间网络模型。
将由基准网络模型输出的分类结果作为监督信息,指导调整待训练网络模型的模型参数,能够有效保证模型训练精度,有效降低模型训练的样本数据量要求,能够有效减少模型训练的成本消耗。
另一示例方式,可以根据针对样本数据的预设分类标签和第二分类结果,确定第三损失函数值。根据第一损失函数值和第三损失函数值,基于针对第一损失函数值和第三损失函数值的预设权重,计算与初始网络模型关联的第一综合损失函数值。
根据针对样本数据的分类标签和第三分类结果,确定与每个第一中间网络模型关联的第四损失函数值。根据第二损失函数值和第四损失函数值,基于针对第二损失函数值和第四损失函数值的预设权重,计算与每个第一中间网络模型关联的第二综合损失函数值。
根据第一综合损失函数值和第二综合损失函数值,确定第二中间网络模型。示例性地,确定最小的第二综合损失函数值,在最小的第二综合损失函数值小于与第一综合损失函数值的情况下,将最小的第二综合损失函数值对应的第一中间网络模型确定为第二中间网络模型。
另一示例方式,通过将训练样本输入基准网络模型、初始网络模型和第一中间网络模型,分别得到由基准网络模型输出的第一样本特征和第一分类结果,由初始网络模型输出的第二样本特征和第二分类结果,由第一中间网络模型输出的第三样本特征和第三分类结果。
根据第一样本特征、第二样本特征、第一分类结果和第二分类结果,确定针对初始网络模型的第一损失函数值。根据第一样本特征、第三样本特征、第一分类结果和第三分类结果,确定针对每个第一中间网络模型的第二损失函数值。根据针对初始网络模型的第一损失函数值和针对每个第一中间网络模型的第二损失函数值,确定第二中间网络模型。
一种示例方式,还可以根据第一模型数据和第二模型数据,确定与初始网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值。根据第一模型数据和第三模型数据,确定与每个第一中间网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值。
嵌入层可用于对输入的样本数据进行特征提取,得到初始样本特征。中间层可用于对初始样本特征进行增强,得到增强后的样本特征。此外,中间层还可用于提取样本特征之间的关联关系,得到特征关联关系,中间层例如可以是自注意力层。输出层可用于根据由中间层提取的特征关联关系和增强后的样本特征,输出针对样本数据的分类结果或回归结果。
示例性地,可以根据与初始网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值,确定针对初始网络模型的第三综合损失函数值。根据与每个第一中间网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值,确定针对对应第一中间网络模型的第四综合损失函数值。根据第三综合损失函数值和第四综合损失函数值,确定第二中间网络模型。
示例性地,还可以根据与初始网络模型和每个第一中间网络模型关联的基于相同模型层结构的损失函数值,确定第二中间网络模型。可以根据与初始网络模型和每个第一中间网络模型关联的基于单个相同模型层结构的损失函数值,确定第二中间网络模型。也可以根据与初始网络模型和每个第一中间网络模型关联的基于多个相同模型层结构的损失函数值,确定第二中间网络模型。
例如,可以根据与初始网络模型和每个第一中间网络模型关联的中间层损失函数值,确定第二中间网络模型,中间层损失函数值可以是根据由中间层输出的样本特征确定的。或者,根据与初始网络模型和每个第一中间网络模型关联的输出层损失函数值,确定第二中间网络模型,输出层损失函数值可以是根据由输出层输出的分类结果确定的。
基于多种类型的模型输出数据,迭代优化待训练网络模型的模型参数,能够有效抑制待训练网络模型的过拟合现象,能够有效提高模型训练的准确性。有利于实现准确可用的轻量级网络模型,能够有效缓解模型尺寸和计算量规模对模型使用场景的限制。
图4示意性示出了根据本公开一实施例的模型训练过程的示意图。
如图4所示,对初始网络模型4a的模型参数进行变换,得到变换后的模型参数。示例性地,基于预设变换参数,对初始网络模型4a的模型参数1、模型参数2、......、模型参数n分别进行变换,得到变换后的模型参数1、变换后的模型参数2、......、变换后的模型参数n。
将样本数据分别输入初始网络模型4a、基准网络模型4c和与变换后的模型参数对应的至少一个第一中间网络模型。至少一个第一中间网络模型例如包括与变换后的模型参数1对应的第一中间网络模型4b1、与变换后的模型参数2对应的第一中间网络模型4b2、......、与变换后的模型参数n对应的第一中间网络模型4bn。
根据由基准网络模型4c输出的第一模型数据和由初始网络模型4a输出的第二模型数据,确定与初始网络模型4a关联的损失函数值。根据由基准网络模型4c输出的第一模型数据和由每个第一中间网络模型(例如包括第一中间网络模型4b1、4b2、......、4bn)输出的第三模型数据,确定与对应第一中间网络模型关联的损失函数值。
根据与每个第一中间网络模型关联的损失函数值,确定最小的损失函数值。在最小的损失函数值小于与初始网络模型4a关联的损失函数值的情况下,将最小的损失函数值对应的第一中间网络模型,作为第二中间网络模型。
确定是否达到预设迭代终止条件,在确定未达到迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。在确定达到迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为训练完成的目标网络模型。
将基准网络模型的模型输出数据作为监督信息,对初始网络模型的模型参数进行迭代调整,以达到训练网络模型的目的,能够有效改善模型训练效率,和有效降低模型训练的样本数据量要求。
图5示意性示出了根据本公开一实施例的对象识别方法的流程图。
如图5所示,本公开实施例的模型训练方法500例如可以包括操作S510~操作S520。
在操作S510,获取待识别的目标数据。
在操作S520,利用对象识别模型,对目标数据进行对象特征提取,得到由对象特征指示的对象识别结果。
示例性地,对象识别模型可以是基于教师网络模型训练得到的。教师网络模型具有较大的模型参数规模和较高的模型准确率,通过将教师网络模型输出的模型数据作为监督信息,指导调整初始学生网络模型的模型参数,得到经训练的学生网络模型。基于经训练的学生网络模型,得到对象识别模型。对象识别模型可以具有较小的模型参数规模和较高的模型准确率。
示例性地,待识别的目标数据例如可以包括待识别的图像数据、语音数据、文本数据等内容。利用对象识别模型,对目标数据进行对象特征提取,得到由对象特征指示的对象识别结果。对象特征例如可以包括对象图像特征、对象声音特征、文本语义特征等信息。
示例性地,可以利用对象识别模型,执行内容推荐操作。例如,可以利用对象识别模型,确定针对目标对象的对象特征数据。基于至少一个候选内容中的目标内容,确定针对目标内容的内容特征数据,以及根据对象特征数据和内容特征数据,得到输出结果,输出结果指示了目标对象针对目标内容的感兴趣程度。响应于输出结果满足预设条件,向目标对象推荐目标内容。
对象识别模型可以是基于教师网络模型得到的轻量级模型,轻量级模型可以有效缓解模型尺寸和计算量规模对使用场景的限制。对象识别模型例如可以部署于车载***、手机终端、智能家居、可穿戴设备等计算及存储能力受限的终端设备中。对象识别模型具有较好的计算性能,可以有效保证对象识别精度。
图6示意性示出了根据本公开一实施例的模型训练装置的框图。
如图6所示,本公开实施例的模型训练装置600例如包括第一处理模块610、第二处理模块620、第三处理模块630和第四处理模块640。
第一处理模块610,用于对初始网络模型的模型参数进行变换,得到变换后的模型参数;第二处理模块620,用于将训练样本分别输入基准网络模型、初始网络模型和与变换后的模型参数对应的至少一个第一中间网络模型;第三处理模块630,用于根据由基准网络模型输出的第一模型数据、由初始网络模型输出的第二模型数据和由每个第一中间网络模型输出的第三模型数据,确定第二中间网络模型;以及四处理模块640,用于在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型,并提示第一处理模块610执行对初始网络模型的模型参数进行变换的操作。
根据分别由基准网络模型、初始网络模型和第一中间网络模型输出的模型数据,迭代优化初始网络模型的模型参数,以实现针对初始网络模型的训练目的。将基准网络模型的模型输出数据作为监督信息,指导调整初始网络模型的模型参数,有利于提高网络模型训练效率,降低网络模型训练的样本数据量要求,有利于快速生成准确可用的目标网络模型。
根据本公开的实施例,第三处理模块包括:第一处理子模块,用于根据第一模型数据和第二模型数据,确定第一损失函数值;第二处理子模块,用于根据第一模型数据和第三模型数据,确定针对每个第一中间网络模型的第二损失函数值;以及第三处理子模块,用于根据第一损失函数值和第二损失函数值,确定第二中间网络模型;第一模型数据、第二模型数据和第三模型数据中均分别包括以下至少一项:样本特征、分类结果。
根据本公开的实施例,第三处理子模块包括:第一处理单元,用于在最小的第二损失函数值小于第一损失函数值的情况下,将最小的第二损失函数值对应的第一中间网络模型确定为第二中间网络模型。
根据本公开的实施例,第一模型数据、第二模型数据和第三模型数据中均分别包括分类结果;第三处理子模块号包括:第二处理单元,用于根据针对样本数据的预设分类标签和第二模型数据,确定第三损失函数值,以及根据第一损失函数值和第三损失函数值,确定第一综合损失函数值;第三处理单元,用于根据分类标签和第三模型数据,确定针对每个第一中间网络模型的第四损失函数值,以及根据第二损失函数值和第四损失函数值,确定针对每个第一中间网络模型的第二综合损失函数值;以及第四处理单元,用于根据第一综合损失函数值和第二综合损失函数值,确定第二中间网络模型。
根据本公开的实施例,第三处理模块包括:第四处理子模块,用于根据第一模型数据和第二模型数据,确定与初始网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;第五处理子模块,用于根据第一模型数据和第三模型数据,确定与每个第一中间网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;以及第六处理子模块,用于根据分别与初始网络模型和每个第一中间网络模型关联的基于相同模型层结构的损失函数值,确定第二中间网络模型。
图7示意性示出了根据本公开一实施例的对象识别装置的框图。
如图7所示,本公开实施例的对象识别装置700例如包括获取模块710和识别模块720。
获取模块710,用于获取待识别的目标数据;以及识别模块720,用于利用对象识别模型,对目标数据进行对象特征识别,得到由对象特征指示的对象识别结果。
对象识别模型可以是基于教师网络模型训练得到的轻量级网络模型,对象识别模型可以具有较小的模型参数规模和较高的模型准确率,有利于有效保证对象识别精度,以及有效缓解模型尺寸和计算量规模对使用场景的限制。
应该注意的是,本公开的技术方案中,所涉及的信息收集、存储、使用、加工、传输、提供和公开等处理,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图8示意性示出了根据本公开实施例的用于执行模型训练方法的电子设备的框图。
图8示出了可以用来实施本公开实施例的示例电子设备800的示意性框图。电子设备800旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图8所示,设备800包括计算单元801,其可以根据存储在只读存储器(ROM)802中的计算机程序或者从存储单元808加载到随机访问存储器(RAM)803中的计算机程序,来执行各种适当的动作和处理。在RAM 803中,还可存储设备800操作所需的各种程序和数据。计算单元801、ROM802以及RAM 803通过总线804彼此相连。输入/输出(I/O)接口805也连接至总线804。
设备800中的多个部件连接至I/O接口805,包括:输入单元806,例如键盘、鼠标等;输出单元807,例如各种类型的显示器、扬声器等;存储单元808,例如磁盘、光盘等;以及通信单元809,例如网卡、调制解调器、无线通信收发机等。通信单元809允许设备800通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元801可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元801的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行深度学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元801执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元808。在一些实施例中,计算机程序的部分或者全部可以经由ROM 802和/或通信单元809而被载入和/或安装到设备800上。当计算机程序加载到RAM 803并由计算单元801执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元801可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的***和技术的各种实施方式可以在数字电子电路***、集成电路***、现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上***的***(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程***上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储***、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储***、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程模型训练装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行***、装置或设备使用或与指令执行***、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体***、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与对象的交互,可以在计算机上实施此处描述的***和技术,该计算机具有:用于向对象显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),对象可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与对象的交互;例如,提供给对象的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自对象的输入。
可以将此处描述的***和技术实施在包括后台部件的计算***(例如作为数据服务器)、或者包括中间件部件的计算***(例如,应用服务器)、或者包括前端部件的计算***(例如,具有图形对象界面或者网络浏览器的对象计算机,对象可以通过该图形对象界面或者该网络浏览器来与此处描述的***和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算***中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将***的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机***可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式***的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (15)

1.一种模型训练方法,包括:
对初始网络模型的模型参数进行变换,得到变换后的模型参数;
将训练样本输入基准网络模型、所述初始网络模型和与所述变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;
根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,其中,所述第一中间网络模型包括所述第二中间网络模型;以及
在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
2.根据权利要求1所述的方法,其中,所述根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,包括:
根据所述第一模型数据和所述第二模型数据,确定第一损失函数值;
根据所述第一模型数据和所述第三模型数据,确定针对每个第一中间网络模型的第二损失函数值;以及
根据所述第一损失函数值和所述第二损失函数值,确定所述第二中间网络模型;
其中,所述第一模型数据、所述第二模型数据和所述第三模型数据中均分别包括以下至少一项:样本特征、分类结果。
3.根据权利要求2所述的方法,其中,所述根据所述第一损失函数值和所述第二损失函数值,确定所述第二中间网络模型,包括:
在最小的第二损失函数值小于所述第一损失函数值的情况下,将所述最小的第二损失函数值对应的第一中间网络模型确定为所述第二中间网络模型。
4.根据权利要求2所述的方法,其中,所述第一模型数据、所述第二模型数据和所述第三模型数据中均分别包括分类结果;所述根据所述第一损失函数值和所述第二损失函数值,确定所述第二中间网络模型,包括:
根据针对所述样本数据的预设分类标签和所述第二模型数据,确定第三损失函数值,以及根据所述第一损失函数值和所述第三损失函数值,确定第一综合损失函数值;
根据所述分类标签和所述第三模型数据,确定针对所述每个第一中间网络模型的第四损失函数值,以及根据所述第二损失函数值和所述第四损失函数值,确定针对所述每个第一中间网络模型的第二综合损失函数值;以及
根据所述第一综合损失函数值和所述第二综合损失函数值,确定所述第二中间网络模型。
5.根据权利要求1所述的方法,其中,所述根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,包括:
根据所述第一模型数据和所述第二模型数据,确定与所述初始网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;
根据所述第一模型数据和所述第三模型数据,确定与所述每个第一中间网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;以及
根据分别与所述初始网络模型和所述每个第一中间网络模型关联的基于相同模型层结构的损失函数值,确定所述第二中间网络模型。
6.一种对象识别方法,包括:
获取待识别的目标数据;以及
利用对象识别模型,对所述目标数据进行对象特征提取,得到由所述对象特征指示的对象识别结果,其中,所述对象识别模型采用根据权利要求1至5中任一项所述的方法生成。
7.一种模型训练装置,包括:
第一处理模块,用于对初始网络模型的模型参数进行变换,得到变换后的模型参数;
第二处理模块,用于将训练样本输入基准网络模型、所述初始网络模型和与所述变换后的模型参数对应的第一中间网络模型,分别对应得到第一模型数据、第二模型数据和第三模型数据;
第三处理模块,用于根据所述第一模型数据、所述第二模型数据和所述第三模型数据,确定第二中间网络模型,其中,所述第一中间网络模型包括所述第二中间网络模型;以及
第四处理模块,用于在确定未达到预设迭代终止条件的情况下,将当前迭代轮次中的第二中间网络模型作为下一迭代轮次中的初始网络模型。
8.根据权利要求7所述的装置,其中,所述第三处理模块包括:
第一处理子模块,用于根据所述第一模型数据和所述第二模型数据,确定第一损失函数值;
第二处理子模块,用于根据所述第一模型数据和所述第三模型数据,确定针对每个第一中间网络模型的第二损失函数值;以及
第三处理子模块,用于根据所述第一损失函数值和所述第二损失函数值,确定所述第二中间网络模型;
其中,所述第一模型数据、所述第二模型数据和所述第三模型数据中均分别包括以下至少一项:样本特征、分类结果。
9.根据权利要求8所述的装置,其中,所述第三处理子模块包括:
第一处理单元,用于在最小的第二损失函数值小于所述第一损失函数值的情况下,将所述最小的第二损失函数值对应的第一中间网络模型确定为所述第二中间网络模型。
10.根据权利要求8所述的装置,其中,所述第一模型数据、所述第二模型数据和所述第三模型数据中均分别包括分类结果;所述第三处理子模块号包括:
第二处理单元,用于根据针对所述样本数据的预设分类标签和所述第二模型数据,确定第三损失函数值,以及根据所述第一损失函数值和所述第三损失函数值,确定第一综合损失函数值;
第三处理单元,用于根据所述分类标签和所述第三模型数据,确定针对所述每个第一中间网络模型的第四损失函数值,以及根据所述第二损失函数值和所述第四损失函数值,确定针对所述每个第一中间网络模型的第二综合损失函数值;以及
第四处理单元,用于根据所述第一综合损失函数值和所述第二综合损失函数值,确定所述第二中间网络模型。
11.根据权利要求7所述的装置,其中,所述第三处理模块包括:
第四处理子模块,用于根据所述第一模型数据和所述第二模型数据,确定与所述初始网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;
第五处理子模块,用于根据所述第一模型数据和所述第三模型数据,确定与所述每个第一中间网络模型关联的分别基于嵌入层、中间层和输出层的损失函数值;以及
第六处理子模块,用于根据分别与所述初始网络模型和所述每个第一中间网络模型关联的基于相同模型层结构的损失函数值,确定所述第二中间网络模型。
12.一种对象识别装置,包括:
获取模块,用于获取待识别的目标数据;以及
识别模块,用于利用对象识别模型,对所述目标数据进行对象特征提取,得到由所述对象特征指示的对象识别结果,其中,所述对象识别模型采用根据权利要求7至11中任一项所述的装置生成。
13.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1至5中任一项所述的模型训练方法,或者执行权利要求6所述的对象识别方法。
14.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行权利要求1至5中任一项所述的模型训练方法,或者执行权利要求6所述的对象识别方法。
15.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现权利要求1至5中任一项所述的模型训练方法,或者实现权利要求6所述的对象识别方法。
CN202210281873.9A 2022-03-21 2022-03-21 模型训练、对象识别方法及装置、设备、介质和产品 Pending CN114707638A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210281873.9A CN114707638A (zh) 2022-03-21 2022-03-21 模型训练、对象识别方法及装置、设备、介质和产品

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210281873.9A CN114707638A (zh) 2022-03-21 2022-03-21 模型训练、对象识别方法及装置、设备、介质和产品

Publications (1)

Publication Number Publication Date
CN114707638A true CN114707638A (zh) 2022-07-05

Family

ID=82168584

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210281873.9A Pending CN114707638A (zh) 2022-03-21 2022-03-21 模型训练、对象识别方法及装置、设备、介质和产品

Country Status (1)

Country Link
CN (1) CN114707638A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116415687A (zh) * 2022-12-29 2023-07-11 江苏东蓝信息技术有限公司 一种基于深度学习的人工智能网络优化训练***及方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116415687A (zh) * 2022-12-29 2023-07-11 江苏东蓝信息技术有限公司 一种基于深度学习的人工智能网络优化训练***及方法
CN116415687B (zh) * 2022-12-29 2023-11-21 江苏东蓝信息技术有限公司 一种基于深度学习的人工智能网络优化训练***及方法

Similar Documents

Publication Publication Date Title
CN114494784A (zh) 深度学习模型的训练方法、图像处理方法和对象识别方法
CN114242113B (zh) 语音检测方法、训练方法、装置和电子设备
CN114187459A (zh) 目标检测模型的训练方法、装置、电子设备以及存储介质
CN116152833B (zh) 基于图像的表格还原模型的训练方法及表格还原方法
CN113360700A (zh) 图文检索模型的训练和图文检索方法、装置、设备和介质
CN110633717A (zh) 一种目标检测模型的训练方法和装置
CN114792355B (zh) 虚拟形象生成方法、装置、电子设备和存储介质
CN113902010A (zh) 分类模型的训练方法和图像分类方法、装置、设备和介质
CN113627536A (zh) 模型训练、视频分类方法,装置,设备以及存储介质
CN115358392A (zh) 深度学习网络的训练方法、文本检测方法及装置
CN113902696A (zh) 图像处理方法、装置、电子设备和介质
CN115456167A (zh) 轻量级模型训练方法、图像处理方法、装置及电子设备
CN115690443A (zh) 特征提取模型训练方法、图像分类方法及相关装置
CN114821063A (zh) 语义分割模型的生成方法及装置、图像的处理方法
CN114494814A (zh) 基于注意力的模型训练方法、装置及电子设备
CN114037059A (zh) 预训练模型、模型的生成方法、数据处理方法及装置
CN117746125A (zh) 图像处理模型的训练方法、装置及电子设备
CN114707638A (zh) 模型训练、对象识别方法及装置、设备、介质和产品
CN116342164A (zh) 目标用户群体的定位方法、装置、电子设备及存储介质
CN113642654B (zh) 图像特征的融合方法、装置、电子设备和存储介质
CN115482443A (zh) 图像特征融合及模型训练方法、装置、设备以及存储介质
CN113806541A (zh) 情感分类的方法和情感分类模型的训练方法、装置
CN113886543A (zh) 生成意图识别模型的方法、装置、介质及程序产品
CN114120416A (zh) 模型训练方法、装置、电子设备及介质
CN113361621A (zh) 用于训练模型的方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination