CN117217368A - 预测模型的训练方法、装置、设备、介质及程序产品 - Google Patents

预测模型的训练方法、装置、设备、介质及程序产品 Download PDF

Info

Publication number
CN117217368A
CN117217368A CN202311137508.1A CN202311137508A CN117217368A CN 117217368 A CN117217368 A CN 117217368A CN 202311137508 A CN202311137508 A CN 202311137508A CN 117217368 A CN117217368 A CN 117217368A
Authority
CN
China
Prior art keywords
data
target domain
network
sample data
prediction
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202311137508.1A
Other languages
English (en)
Inventor
刘洋
赵子敬
魏斯桐
陈庆超
彭宇新
李德辉
杨一帆
王巨宏
钟学丹
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peking University
Tencent Technology Shenzhen Co Ltd
Original Assignee
Peking University
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peking University, Tencent Technology Shenzhen Co Ltd filed Critical Peking University
Priority to CN202311137508.1A priority Critical patent/CN117217368A/zh
Publication of CN117217368A publication Critical patent/CN117217368A/zh
Pending legal-status Critical Current

Links

Landscapes

  • Image Analysis (AREA)

Abstract

本申请公开了一种预测模型的训练方法、装置、设备、介质及程序产品,涉及人工智能领域。该方法包括:对目标域样本数据进行变换处理,得到目标域变换数据,并通过第一编码器对目标域变换数据进行编码处理,得到编码数据;基于编码数据和目标域样本数据对第一编码器进行训练,得到第二编码器;基于第二编码器更新第一预测网络,得到教师网络和学生网络;基于源域样本数据和目标域样本数据通过教师网络输出的伪标签,对学生网络进行训练,得到第二预测网络。对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。

Description

预测模型的训练方法、装置、设备、介质及程序产品
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种预测模型的训练方法、装置、设备、介质及程序产品。
背景技术
领域自适应(Domain Adaptation)是迁移学习中的一个重要部分,旨在将知识从标签丰富的领域(源域)转移到相关但标签稀疏的领域(目标域),领域自适应能够减少标签标注工作,提高模型的训练效率。
相关技术中,通过构建教师-学生网络来达到领域自适应的目的。首先,在源域数据和源域标签上训练一个候选模型,使得候选模型具有对源域数据的处理能力;其次,将候选模型拓展为一个学生网络和一个教师网络,教师网络处理目标域数据以产生伪标签,学生网络处理目标域数据和源域数据得到预测结果;最后,通过源域标签对源域数据预测结果进行监督,通过伪标签对学生网络的目标域数据预测结果进行监督,从而训练得到一个适应源域和目标域的数据的通用模型。
然而,相关技术中产出的伪标签质量较差,导致对学生网络的训练效果较差。
发明内容
本申请实施例提供了一种预测模型的训练方法、装置、设备、介质及程序产品,能够提高教师网络产出的伪标签的质量,从而提升对学生网络的训练效果,所述技术方案如下:
一方面,提供了一种预测模型的训练方法,所述方法包括:
获取第一预测网络,所述第一预测网络是通过源域数据在第一预测任务中训练得到的网络,所述第一预测网络中包括第一编码器;
对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,得到编码数据;
基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器;
基于所述第二编码器更新所述第一预测网络,得到教师网络和学生网络;
基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,所述第二预测网络用于在所述第一预测任务中对源域数据或者目标域数据进行预测。
另一方面,提供了一种预测模型的训练装置,所述装置包括:
获取模块,用于获取第一预测网络,所述第一预测网络是通过源域数据在第一预测任务中训练得到的网络,所述第一预测网络中包括第一编码器;
变换模块,用于对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,得到编码数据;
训练模块,用于基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器;
更新模块,用于基于所述第二编码器更新所述第一预测网络,得到教师网络和学生网络;
所述训练模块,还用于基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,所述第二预测网络用于在所述第一预测任务中对源域数据或者目标域数据进行预测。
另一方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如上述实施例中任一所述预测模型的训练方法。
另一方面,提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如上述实施例中任一所述的预测模型的训练方法。
另一方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的预测模型的训练方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
获取通过源域数据训练的用于执行第一预测任务的第一预测网络后,对第一预测网络中的第一编码器进行跨域特征学习,即将变换后的目标域样本数据输入第一编码器,用原有的目标域样本数据作为标签对第一编码器进行监督训练,得到第二编码器;后续,根据训练好的第二编码器构建教师网络和学生网络,并基于源域样本数据和目标域样本数据,通过教师网络输出的伪标签对学生网络进行训练,得到在第一预测任务中具备源域数据和目标域数据预测能力的第二预测模型。其中,训练学生网络在第一预测任务中适应目标域数据之前,对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个示例性实施例提供的实施环境的示意图;
图2是本申请一个示例性实施例提供的预测模型的训练方法的流程图;
图3是本申请另一个示例性实施例提供的预测模型的训练方法的流程图;
图4是本申请又一个示例性实施例提供的预测模型的训练方法的流程图;
图5是本申请一个示例性实施例提供的预测模型的训练流程示意图;
图6是本申请一个示例性实施例提供的源域图像的示意图;
图7是本申请一个示例性实施例提供的目标域图像未使用本申请实施例提供的方法的检测效果示意图;
图8是本申请一个示例性实施例提供的目标域图像使用本申请实施例提供的方法的检测效果示意图;
图9是本申请一个示例性实施例提供的目标域图像真实的检测效果示意图;
图10是本申请一个示例性实施例提供的预测模型的训练装置的结构框图;
图11是本申请另一个示例性实施例提供的预测模型的训练装置的结构框图;
图12是本申请一个示例性实施例提供的计算机设备的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请中术语“第一”、“第二”等字样用于对作用和功能基本相同的相同项或相似项进行区分,应理解,“第一”、“第二”之间不具有逻辑或时序上的依赖关系,也不对数量和执行顺序进行限定。
首先,针对本申请实施例中涉及的名词进行简单介绍。
人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用***。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、预训练模型技术、操作/交互***、机电一体化等。其中,预训练模型又称大模型、基础模型,经过微调后可以广泛应用于人工智能各大方向下游任务。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
无监督学习(Unsupervised Learning,UL):又称非监督式学习,是机器学习的一种方法,没有给定事先标记过的训练范例,自动对输入的资料进行分类或分群。
领域自适应(Domain Adaptation,DA):简称跨域,旨在从源数据分布中学习,在不同(但相关)的目标数据分布上构建出表现良好的模型。
目标检测(Object Detection,OD):检测出图像中具有语义信息的特定类别目标(物体),确定它们的类别和位置。
相关技术中,通过构建教师-学生网络来达到领域自适应的目的。首先,在源域数据和源域标签上训练一个候选模型,使得候选模型具有对源域数据的处理能力;其次,将候选模型扩展为一个学生网络和一个教师网络,教师网络处理目标域数据以产生伪标签,学生网络处理目标域数据和源域数据得到预测结果;最后,通过源域标签对源域数据预测结果进行监督,通过伪标签对学生网络的目标域数据预测结果进行监督,从而训练得到一个适应源域和目标域的数据的通用模型。然而,相关技术中产出的伪标签质量较差,导致对学生网络的训练效果较差。
在本申请实施例中,训练学生网络在第一预测任务中适应目标域数据之前,对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。应用本申请实施例提供的预测模型的训练方法训练得到的预测模型可以应用于多种预测任务,例如:图像分类任务、目标检测任务、语义分割任务、实例分割任务等,本申请实施例对此不加以限定。
其次,对本申请实施例中涉及的实施环境进行说明,本申请实施例提供的预测模型的训练方法可以由终端单独执行实现,也可以由服务器执行实现,或者由终端和服务器通过数据交互实现,本申请实施例对此不加以限定。可选地,以终端和服务器交互执行预测模型的训练方法为例进行说明。
示意性的,请参考图1,该实施环境中涉及终端110、服务器120,终端110和服务器120之间通过通信网络130连接。可选地,通信网络130为有线网络或者无线网络。
在一些实施例中,终端110通过通信网络130向服务器120发送训练数据,该训练数据包括源域样本数据和目标域样本数据,其中,源域样本数据是有标签的数据,而目标域样本数据是没有标签的数据。
在一些实施例中,服务器120中存储有第一预测网络,该第一预测网络用于在第一预测任务中对源域数据进行预测;或者,服务器120中存储有待训练的预测网络,服务器120接收到训练数据后,通过源域样本数据和源域标签对待训练的预测网络进行训练,得到第一预测网络。
在获取第一预测网络后,服务器120首先对接收到的目标域样本数据进行变换处理,得到目标域变换数据,并通过第一预测网络中的第一编码器对目标域变换数据进行编码处理,得到编码数据;其次,基于编码数据和目标域样本数据对第一编码器进行训练,得到第二编码器;然后,基于第二编码器更新第一预测网络,得到教师网络和学生网络;最后,基于源域样本数据和目标域样本数据通过教师网络输出的伪标签,对学生网络进行训练,得到第二预测网络。
上述终端110包括但不限于手机、平板电脑、便携式膝上笔记本电脑、智能语音交互设备、智能家电、车载终端等移动终端,也可以实现为台式电脑等。
可选地,终端110中安装具有目标任务预测功能的应用程序,其中,目标任务包括图像分类任务、目标检测任务、语义分割任务、实例分割任务等中的至少一种,本申请实施例对此不加以限定。示意性的,上述应用程序包括即时通讯应用程序、新闻资讯应用程序、综合搜索引擎应用程序、社交应用程序、游戏应用程序、购物应用程序、地图导航应用程序等;或者,该应用程序实现为依赖宿主应用程序的小程序,宿主应用程序可以实现为上述任意程序,本申请实施例对此不加以限定。
可选地,服务器120训练得到第二预测网络后,通过第二预测网络为终端110中具有目标任务预测功能的应用程序提供后台计算服务;或者,服务器120将训练得到的第二预测网络发送至终端110,则终端110能够单独实现目标任务预测功能。
值得注意的是,服务器120能够是独立的物理服务器,也能够是多个物理服务器构成的服务器集群或者分布式***,还能够是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN(Content Delivery Network,内容分发网络)以及大数据和人工智能平台等基础云计算服务的云服务器。
其中,云技术(Cloud Technology)是指在广域网或局域网内将硬件、软件、网络等系列资源统一起来,实现数据的计算、储存、处理和共享的一种托管技术。云技术基于云计算商业模型应用的网络技术、信息技术、整合技术、管理平台技术、应用技术等的总称,可以组成资源池,按需所用,灵活便利。云计算技术将变成重要支撑。技术网络***的后台服务需要大量的计算、存储资源,如视频网站、图片类网站和更多的门户网站。伴随着互联网行业的高度发展和应用,将来每个物品都有可能存在自己的识别标志,都需要传输到后台***进行逻辑处理,不同程度级别的数据将会分开处理,各类行业数据皆需要强大的***后盾支撑,只能通过云计算来实现。可选地,服务器120还可以实现为区块链***中的节点。
需要进行说明的是,本申请在收集用户的相关数据之前以及在收集用户的相关数据的过程中,都可以显示提示界面、弹窗或输出语音提示信息,该提示界面、弹窗或语音提示信息用于提示用户当前正在搜集其相关数据,使得本申请仅仅在获取到用户对该提示界面或者弹窗发出的确认操作后,才开始执行获取用户相关数据的相关步骤,否则(即未获取到用户对该提示界面或者弹窗发出的确认操作时),结束获取用户相关数据的相关步骤,即不获取用户的相关数据。换句话说,本申请所采集的所有用户数据都是在用户同意并授权的情况下进行采集的,且相关用户数据的收集、使用和处理需要遵守相关法律法规和标准。
结合上述介绍和实施环境,对本申请提供的预测模型的训练方法进行说明,以该方法应用于服务器为例进行说明,如图2所示,该方法包括如下步骤210至步骤250。
步骤210,获取第一预测网络。
其中,第一预测网络是通过源域数据在第一预测任务中训练得到的网络,第一预测网络中包括第一编码器。
可选地,源域数据是指从已标注的源域数据集中收集的数据,源域数据通常具有较高质量的任务标签。示意性的,源域数据的数据类型包括文本类型、图像类型、视频类型、音频类型等中的至少一种。
可选地,第一预测任务是指与源域数据相关的任一任务,第一预测任务包括图像分类任务、目标检测任务、语义分割任务、实例分割任务等中的至少一种,本申请实施例对此不加以限定。
可选地,第一编码器用于对源域数据进行编码得到编码数据,编码数据是指源域数据对应的特征表示。可选地,第一预测网络中还包括第一解码器,第一解码器用于对编码数据进行解码得到源域数据在第一预测任务中的任务预测结果。
在一些实施例中,第一预测网络是通过源域样本数据和源域标签对候选预测网络进行训练得到的网络,候选预测网络中包括候选解码器和候选编码器。
可选地,获取第一预测网络的方法包括:获取源域样本数据和源域标签,源域标签用于指示源域样本数据在第一预测任务中的参考结果;通过候选预测网络中的候选编码器对源域样本数据进行编码,得到源域编码数据;通过候选预测网络中的候选解码器对源域编码数据进行解码得到源域预测结果;基于源域预测结果和源域标签之间的差异对候选预测网络进行训练,得到第一预测网络。
示意性的,以目标检测任务为例进行说明的,源域样本数据可实现为在晴天环境下拍摄的道路图像,源域标签实现为该道路图像对应的参考检测框,参考检测框用于框选道路图像中的物体(例如:行人、车辆、房屋等)。将道路图像输入候选预测网络中,通过候选编码器对道路图像进行编码,获取道路图像对应的图像特征表示;通过候选解码器对图像特征表示进行解码,得到道路图像对应的预测检测框;基于预测检测框和参考检测框之间的差异对候选预测网络进行训练,得到第一预测网络,第一预测网络可称为目标检测器,目标检测器具有对晴天环境下(即源域)拍摄的道路图像进行物体检测的能力。
步骤220,对目标域样本数据进行变换处理,得到目标域变换数据,并通过第一编码器对目标域变换数据进行编码处理,得到编码数据。
可选地,目标域样本数据是指从与源域不同的目标域数据集中收集的数据,目标域样本数据通常是不具有任务标签的数据。示意性的,目标域数据的数据类型包括文本类型、图像类型、视频类型、音频类型等中的至少一种。
可选地,源域样本数据和目标域样本数据是存在关联关系的数据。示意性的,源域样本数据实现为在晴天环境下拍摄的道路图像,目标域样本数据实现为雾天环境下拍摄的道路图像;或者,源域样本数据实现为通过正常摄像头拍摄的图像,目标域样本数据实现为通过鱼眼摄像头拍摄的图像;或者,源域样本数据实现为写实风格的图像,目标域样本数据实现卡通风格的图像等。
可选地,对目标域样本数据进行变换处理的情况包括以下情况中的至少一种:
情况一:对目标域样本数据进行掩码处理,得到掩码后的目标域样本数据作为目标域变换数据。
示意性的,以目标域样本数据实现为图像为例进行说明,对图像进行掩码处理,即将图像中的一部分图像块进行遮盖,遮盖后的图像即为掩码后的图像。
将掩码后的目标域样本数据输入第一编码器进行编码,得到掩码编码数据。
情况二:对目标域样本数据进行增强处理,得到增强后的目标域样本数据作为目标域变换数据。
其中,增强后的目标域样本数据包括目标域样本数据对应的正样本数据和负样本数据,正样本数据可理解为与目标域样本数据相似度较高的数据,负样本数据可理解为与目标域样本数据相似度较低的数据。
示意性的,以目标域样本数据实现为图像为例进行说明,对图像进行增强处理的方法包括旋转、缩放、平移、镜像翻转、亮度调整、对比度调整等,本申请实施例对此不加限定。
将目标域样本数据、正样本数据和负样本数据输入第一编码器进行编码,得到目标域样本数据对应的目标域编码数据、正样本数据对应的正样本编码数据和负样本数据对应的负样本编码数据。
需要进行说明的是,上述对目标域样本数据进行变换处理的方法仅为示意性的举例,本申请实施例对变换处理的方法不进行限定。例如:将彩色的目标域图像变换为黑白的目标域图像等。
示意性的,得到目标域变换数据后,将目标域变换数据输入至第一编码器中,通过第一编码器对其进行编码处理,得到编码数据,则该编码数据即为目标域变换数据对应的特征表示。可以理解的是,由于第一编码器是用于执行第一预测任务的第一预测网络中的编码器,则通过该第一编码器得到的编码数据是指与第一预测任务相关的特征数据。
步骤230,基于编码数据和目标域样本数据对第一编码器进行训练,得到第二编码器。
示意性的,基于编码数据和目标域样本数据对第一编码器进行训练,即第一编码器对目标域样本数据进行特征学习,使得第一编码器适应目标域样本数据。
在一些实施例中,若目标域变换数据实现为对目标域样本数据进行掩码处理后的数据,参与训练的模型结构还包括掩码解码器,掩码解码器用于根据解码数据对掩码后的目标域样本数据进行数据重建。则训练得到第二编码器的方法包括:
将掩码编码数据输入掩码解码器进行解码,得到解码数据,解码数据用于表征对掩码后的目标域样本数据进行数据重建得到的数据;基于解码数据和目标域样本数据之间的差异对第一编码器进行训练,得到第二编码器。
示意性的,假设目标域变换数据是指经过遮盖的目标域图像,则对其进行编码后,得到掩码编码数据,将掩码编码数据输入掩码解码器的目的即是将遮盖的图像块进行还原,得到重建图像(即解码数据),根据重建图像和目标域图像之间的差异对第一编码器进行训练,从而使得第一编码器适应目标域样本数据。
在一些实施例中,若目标域变换数据实现为对目标域样本数据进行增强处理后的数据。则对第一编码器进行训练的方法包括:
以最小化目标域编码数据和正样本编码数据之间的差异,同时最大目标域编码数据和负样本编码数据之间的差异为目标,对第一编码器进行训练,得到第二编码器。
其中,单个目标域样本数据对应的正样本数据的数量可以是一个,也可以是多个;单个目标域样本数据对应的负样本数据的数量可以是一个,也可以是多个。
示意性的,针对每个目标域样本数据,通过增强处理,得到多个正样本数据和多个负样本数据;通过第一编码器对单个目标域样本数据、多个正样本数据、多个负样本数据进行编码,得到各个数据分别对应的编码数据,通过拉近目标域样本数据和正样本数据对应的编码数据之间的距离,拉远目标域样本数据和负样本数据对应的编码数据之间的距离对第一编码器进行训练,使得第一编码器适应目标域样本数据。
可选地,进行多次迭代循环训练,直至第一编码器对应的训练损失收敛,停止训练;或者达到预设训练次数,停止训练。停止训练后得到的第一编码器即为第二编码器。
步骤240,基于第二编码器更新第一预测网络,得到教师网络和学生网络。
示意性的,基于第二编码器更新第一预测网络即将第一预测网络中第一编码器替换为第二编码器,即将第一预测网络中的第一编码器的模型参数更新为第二编码器的模型参数,从而得到经过参数更新的第一预测网络;将该第一预测网络拓展为架构相同的教师网络和学生网络,该教师网络和学生网络均以经过参数更新的第一预测网络的模型参数为初始参数。
步骤250,基于源域样本数据和目标域样本数据通过教师网络输出的伪标签,对学生网络进行训练,得到第二预测网络。
其中,第二预测网络用于在第一预测任务中对源域数据或者目标域数据进行预测。
示意性的,教师网络用于对目标域样本数据进行处理得到伪标签,伪标签用于代替目标域标签,即表示目标域样本数据在第一预测任务中的参考结果。学生网络用于对源域样本数据和目标域样本数据进行处理,一方面,通过伪标签对目标域样本数据进行监督;另一方面,通过对源域样本数据和目标域样本数据对应的编码数据进行对齐,从而对学生网络进行训练,提升学生网络在第一预测任务中对目标域数据的预测能力,得到训练后的学生网络。
在一些实施例中,对学生网络进行训练的方法包括以下步骤:
步骤一:对目标域样本数据进行数据强增强处理,得到强增强目标域样本数据;通过学生网络对强增强目标域样本数据进行任务预测,得到目标域预测结果,目标域预测结果用于指示目标域样本数据在第一预测任务中的预测结果。
其中,数据强增强处理通常对目标域样本数据进行大幅度、复杂的变换,以产生与目标域样本数据具有明显差异的数据。
步骤二:对源域样本数据进行数据强增强处理,得到强增强源域样本数据;通过学生网络对强增强源域样本数据进行任务预测,得到源域预测结果,源域预测结果用于指示源域样本数据在第一预测任务中的预测结果。
示意性的,在第一次迭代时,学生网络中包含的编码器为第二编码器,解码器为第一预测网络中的第一解码器,则将强增强目标域样本数据、强增强源域样本数据输入第二编码器进行编码得到目标域编码数据、源域编码数据,通过第一解码器对目标域编码数据、源域编码数据进行解码,得到目标域预测结果和源域预测结果。
步骤三:对目标域样本数据进行数据弱增强处理,得到弱增强目标域样本数据;通过教师网络对弱增强目标域样本数据进行伪标签预测,得到伪标签。
其中,数据弱增强处理是指对目标域样本数据进行轻微、简单的变换,以产生与目标域样本数据仅存在细节差异的数据。
示意性的,强增强处理对数据的调整程度大于弱增强处理对数据的调整程度。例如:强增强处理后的数据与原数据之间的差异度大于弱增强处理后数据与原数据之间的差异度。
可选地,当强增强和弱增强是同一种增强形式时,强增强处理对数据的增强比例大于弱增强处理对数据的增强比例。
示意性的,在第一次迭代时,教师网络中包含的编码器为第二编码器,解码器为第一预测网络中的第一解码器,则将弱增强目标域样本数据输入第二编码器进行编码得到目标域编码数据,通过第一解码器对目标域编码数据进行解码,得到弱增强下的目标域预测结果,将该目标域预测结果作为伪标签。
在一些实施例中,教师网络中还设置有置信度网络,该置信度网络用于预测教师网络产生的伪标签的置信度,可通过设置一个置信度阈值对教师网络产生的伪标签进行筛选。
以第一预测任务实现为目标检测任务为例进行说明,可选地,通过教师网络对弱增强目标域样本数据进行伪标签预测,得到多个候选伪标签以及多个候选伪标签分别对应的分类置信度,多个候选伪标签用于指示对弱增强目标域样本数据进行物体识别后得到的检测框,分类置信度用于指示检测框框选的物体类别的置信度;将多个候选伪标签中分类置信度大于置信度阈值的候选伪标签确定为伪标签。
示意性的,只有当教师网络产生的伪标签的置信度大于置信度阈值时,该伪标签才能被传递到学生网络中,用于学生网络进行训练。
步骤四:基于源域预测结果和源域标签之间的差异,确定源域任务损失;基于目标域预测结果和伪标签之间的差异,确定目标域任务损失。
其中,源域样本数据对应有源域标签,源域标签用于指示源域样本数据在第一预测任务中的参考结果,则获取源域标签后进行源域任务损失的确定。
其中,确定源域任务损失和目标域任务损失的损失函数包括交叉熵损失函数、L1范数损失函数、L2范数损失函数等中的至少一种,本申请实施例对此不加以限定。
步骤五:基于源域任务损失和目标域任务损失对学生网络进行训练,得到训练后的学生网络。
可选地,对源域任务损失和目标域任务损失进行加权融合,得到融合损失;基于融合损失对学生网络进行训练,得到训练后的学生网络。
在一些实施例中,在通过学生网络对强增强目标域样本数据、强增强源域样本数据进行任务预测时,还需要对学生网络提取的目标域样本数据和源域样本数据对应的特征进行对齐处理。则对学生网络进行训练的方法,还包括:通过学生网络对强增强目标域样本数据、强增强源域样本数据进行任务预测时,对强增强目标域样本数据、强增强源域样本数据进行特征对齐识别,得到域特征对齐损失;对源域任务损失、目标域任务损失和域特征对齐损失进行加权融合,得到融合损失;基于融合损失对学生网络进行训练,得到训练后的学生网络。
示意性的,确定域特征对齐损失的过程是类似于生成对抗网络的方式,在学生网络中,为使得源域样本数据和目标域样本数据产生的特征进行对齐,可在学生网络中设置一个判别器,将学生网络产生的特征送入判别器做二分类,分类结果的含义为:该特征属于源域还是目标域;该判别器使用已知的域标签(源域标签为1、目标域标签为0;或者,源域标签为0、目标域标签为1)进行监督,产生判别损失。在对抗训练中,学生网络尽可能产生混淆两个域的特征,使判别器不能分辨,而判别器尽可能分辨两个域的特征,以达到对齐源域样本数据和目标域样本数据对应的特征表示的目的。
可选地,将训练后的学生网络作为第二预测网络;或者,将训练后的学生网络的模型参数迁移至教师网络中,得到参数迁移后的教师网络作为第二预测网络。
可选地,进行多次迭代循环训练,直至学生网络对应的融合损失收敛,停止训练;或者达到预设训练次数,停止训练。停止训练后得到的学生网络即为训练后的学生网络。
可选地,在单次迭代时,对学生网络进行参数更新后,将更新后的学生网络的参数迁移至教师网络中,得到参数更新后的教师网络以进行下一次迭代训练。
综上所述,本申请实施例提供的预测模型的训练方法,获取通过源域数据训练的用于执行第一预测任务的第一预测网络后,对第一预测网络中的第一编码器进行跨域特征学习,即将变换后的目标域样本数据输入第一编码器,用原有的目标域样本数据作为标签对第一编码器进行监督训练,得到第二编码器;后续,根据训练好的第二编码器构建教师网络和学生网络,并基于源域样本数据和目标域样本数据,通过教师网络输出的伪标签对学生网络进行训练,得到在第一预测任务中具备源域数据和目标域数据预测能力的第二预测模型。其中,训练学生网络在第一预测任务中适应目标域数据之前,对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。
本申请实施例提供的方法,通过掩码自编码的方法或者对比学习的方法对第一编码器进行跨域特征学习,得到第二编码器,使得训练得到的第二编码器能够熟悉目标域数据的特征,从而提高了基于第二编码器扩展得到的教师模型产出的伪标签的质量。
本申请实施例提供的方法,基于源域预测结果和源域标签之间的源域任务损失、目标域预测结果和伪标签之间的目标域任务损失对学生网络进行训练,得到第二预测网络,在提高了第二预测网络对目标域数据的预测能力的同时,兼顾了对源域数据的预测能力,从而提升了模型的通用性。
本申请实施例提供的方法,通过源域样本数据和目标域样本数据之间的域特征对齐损失以及源域任务损失、目标域任务损失对学生网络进行训练,得到第二预测网络,通过加入域对齐损失使得源域样本数据和目标域样本数据对应的特征尽可能处于同一分布,从而提升模型对源域数据和目标域数据的预测准确性。
本申请实施例提供的方法,通过一个置信度阈值来控制产生伪标签的数量和质量,进一步提升对学生网络的训练效果。
在一些实施例中,针对伪标签噪声问题,本申请设计了重训练机制,将学生网络间隔一定训练轮次重新初始化,采用持续更新的教师网络进行重训练,以此允许教师-学生网络跳出由伪标签噪声造成的局部最优。
需要进行说明的是,在学生网络中进行重训练的模块可以为任意模块,本申请实施例对进行重训练的模块数量和类型不加以限定,其中,选取的模型可根据实际的训练效果确定。本申请实施例中,以重训练的模块至少包括解码器为例进行说明。示意性的,如图3所示,上述图2中的实施例还可实现步骤310至步骤353。
步骤310,获取第一预测网络。
其中,第一预测网络是通过源域数据在第一预测任务中训练得到的网络。
可选地,第一预测网络由主干网络、第一编码器和第一解码器组成。
示意性的,主干网络用于标注第一预测网络对应的输入层网络,用于将输入数据转换为输入向量表示;第一编码器用于对主干网络的输出进行编码处理,得到编码数据,第一解码器用于根据第一预测任务的任务目标对编码数据进行解码,得到输入数据在第一预测任务中的预测结果。
步骤320,对目标域样本数据进行变换处理,得到目标域变换数据,并通过第一编码器对目标域变换数据进行编码处理,得到编码数据。
可选地,对目标域样本数据进行掩码处理,得到掩码后的目标域样本数据作为目标域变换数据。
示意性的,参与编码处理的模型结构还包括主干网络,则将目标域样本数据输入主干网络,输出目标域样本数据对应的目标域样本特征表示;对目标域样本特征表示进行掩码处理,得到目标域掩码特征表示作为目标域变换数据;将目标域掩码特征表示输入第一编码器进行编码,得到掩码编码数据。
或者,对目标域样本数据进行增强处理,得到增强后的目标域样本数据作为目标域变换数据。其中,增强后的目标域样本数据包括正样本数据和负样本数据。
示意性的,对目标域样本数据进行增强处理,得到目标域样本数据对应的正样本数据和负样本数据;将正样本数据、负样本数据和目标域样本数据输入主干网络,输出正样本数据对应的正样本特征表示、负样本数据对应的负样本特征表示和目标域样本数据对应的目标域样本特征表示作为目标域变换数据;将正样本特征表示、负样本特征表示和目标域特征表示输入第一编码器进行编码,得到正样本编码数据、负样本编码数据和目标域编码数据。
步骤330,基于编码数据和目标域样本数据对第一编码器进行训练,得到第二编码器。
在一些实施例中,若目标域变换数据实现为对目标域样本数据进行掩码处理后的数据,参与训练的模型结构还包括掩码解码器。则训练得到第二编码器的方法包括:将掩码编码数据输入掩码解码器进行解码,得到解码数据,解码数据用于表示对目标域变换数据进行数据重建得到的数据;基于解码数据和目标域样本数据之间的差异对第一编码器进行训练,得到第二编码器。
可选地,基于解码数据和目标域样本数据之间的差异,确定第一损失;基于第一损失对第一编码器的参数进行更新,得到第二解码器。
可选地,基于第一损失对主干网络的参数进行更新,得到参数更新后的主干网络。
在另一些实施例中,以最小化目标域编码数据和正样本编码数据之间的差异,同时最大目标域编码数据和负样本编码数据之间的差异为目标,对第一编码器进行训练,得到第二编码器。
可选地,基于目标域编码数据和正样本编码数据之间的差异确定第一子损失;基于目标域编码数据和负样本编码数据之间的差异确定第二子损失;以最小化第一子损失、最大化第二子损失为目标,对第一编码器的参数进行更新,得到第二编码器。
可选地,以最小化第一子损失、最大化第二子损失为目标,对主干网络的参数进行更新,得到参数更新后的主干网络。
步骤340,基于第二编码器更新第一预测网络,得到教师网络和学生网络。
示意性的,将第一预测网络中的第一编码器更新为第二编码器,初始状态下,教师网络和学生网络的网络结构均实现为主干网络、第二解码器和第一编码器。
在另一些实施例中,将第一预测网络中的第一编码器更新为第二编码器,将第一预测网络中的主干网络更新为参数更新后的主干网络,初始状态下,教师网络和学生网络的网络结构均实现为参数更新后的主干网络、第二解码器和第一编码器。
步骤351,在第t轮次迭代更新中,基于源域样本数据和目标域样本数据通过第t-1轮次更新得到的教师网络输出的伪标签,对第t-1轮次更新得到的学生网络进行更新,得到第t轮次更新得到的学生网络和第t轮次更新得到的教师网络。
其中,t为大于1的整数。
可选地,对第t-1轮次更新得到的学生网络进行更新,得到第t轮次更新得到的学生网络的方法可参考步骤250中对学生网络进行训练的方法,此处不再赘述。
本实施例中,分别对学生网络的主干网络、编码器和解码器采用特征对齐识别,得到主干网络对应的第一对齐损失、编码器对应的第二对齐损失和解码器对应的第三对齐损失中的至少一种作为域特征对齐损失。
可选地,以域特征对齐损失包含第一对齐损失、第二对齐损失和第三对齐损失为例进行说明,对第t-1轮次更新得到的学生网络进行更新,得到第t轮次更新得到的学生网络的方法包括:
步骤一:对目标域样本数据、源域样本数据进行数据强增强处理,得到强增强目标域样本数据和强增强源域样本数据。
步骤二:将强增强目标域样本数据和强增强源域样本数据输入第t-1轮次更新得到的学生网络中的主干网络,输出目标域样本特征表示和源域样本特征表示。
步骤三:将目标域样本特征表示和源域样本特征表示输入第t-1轮次更新得到的学生网络中的第二编码器中,输出目标域编码数据和源域编码数据。
步骤四:将目标域编码数据和源域编码数据输入第t-1轮次更新得到的学生网络中的第一解码器中,输出目标域解码数据和源域解码数据。
示意性的,目标域解码数据和源域解码数据可以理解为解码器生成的中间特征表示,该中间特征表示包含了第一预测任务的任务目标信息,但不是最终的任务预测结果。
步骤五:基于目标域解码数据,确定目标域预测结果;基于源域解码数据,确定源域预测结果。
以目标检测任务为例进行说明,解码数据可以是检测到的各个物体的特征表示以及置信度;后续,可通过对该物体特征表示进行分析,得到物体类别和位置,并根据各个物体特征对应的置信度确定最终的目标检测结果。
步骤六:对目标域样本数据进行数据弱增强处理,得到弱增强目标域样本数据;通过第t-1轮次更新得到的教师网络对弱增强目标域样本数据进行伪标签预测,得到伪标签。
步骤七:基于目标域样本特征表示和源域样本特征表示之间的差异,确定第一对齐损失;基于目标域编码数据和源域编码数据之间的差异,确定第二对齐损失;基于目标域解码数据和源域解码数据之间的差异,确定第三对齐损失;对第一对齐损失、第二对齐损失和第三对齐损失进行加权融合,得到域特征对齐损失。
示意性的,以第一对齐损失为例进行说明,构建第一对齐损失函数,第一对齐损失函数的训练目标是,最小化判别器对应的域分类损失,同时最大化输入判别器的目标域样本特征表示和源域样本特征表示之间的差异。第一对齐损失函数的公式如下公式一所示:
公式一:
其中,是指第一对齐损失、S是指学生网络的参数(或者,主干网络的参数)、/>是指判别器参数、/>为判别器对应的域分类损失。
需要进行说明的是,第二对齐损失和第三对齐损失的计算方法与第一对齐损失类似,此处不再赘述。其中,每个对齐损失具有的判别器是不同的;或者三种对齐损失对应一个判别器。
可选地,在得到第一对齐损失、第二对齐损失和第三对齐损失后,计算第一对齐损失、第二对齐损失和第三对齐损失之和,得到域特征对齐损失。
步骤八:基于源域预测结果和源域标签之间的差异,确定源域任务损失;基于目标域预测结果和伪标签之间的差异,确定目标域任务损失。
步骤九:对源域任务损失、目标域任务损失和域特征对齐损失进行加权融合,得到融合损失;基于融合损失对第t-1轮次更新得到的学生网络进行训练,得到第t轮次更新得到的的学生网络。
示意性的,基于融合损失采用梯度更新的方式对第t-1轮次更新得到的学生网络中的模型参数进行更新,得到第t轮次更新得到的学生网络。
在一些实施例中,教师网络的参数更新方式不是通过梯度更新,而是通过接收学生网络的模型参数进行更新。
示意性的,采用指数滑动均值(Exponential Moving Average,EMA)的方式将学生网络的参数迁移至教师网络中,使得教师网络更新地更平缓一些。
可选地,在得到第t轮次更新得到的学生网络后,获取第t轮次更新得到的学生网络的第一模型参数;获取第t-1轮次更新得到的教师网络的第二模型参数;根据预设更新参数对第一模型参数和第二模型参数进行加权融合,得到融合模型参数;基于融合模型参数对第t-1轮次更新得到的教师网络进行更新,得到第t轮次更新得到的教师网络。
示意性的,第t轮次更新得到的教师网络的计算公式如下公式二所示:
公式二:θt←αθt+(1-α)θs
其中,θt是指第t-1轮次更新得到的教师网络的第二模型参数,θs是指第t轮次更新得到的学生网络的第一模型参数,α是指超参数(即预设更新参数)。
其中,当α接近1时,EMA方法更加关注最近更新的参数,而历史参数的权重较低;当α接近0时,EMA方法更加关注历史参数,最近更新的参数的权重较低。因此,α的取值可以调节EMA方法对于模型参数变化的敏感度。
步骤352,将第t轮次更新得到的学生网络中编码器的参数初始化为第二编码器的参数,得到第t轮次重置的学生网络。
示意性的,在t个训练轮次后,对学生网络中编码器的参数进行初始化,使其重置为通过步骤330训练得到的第二编码器的参数,从而得到一个重置的学生网络作为第t+1个训练轮次的初始学生网络。
可选地,在学生网络中还包括主干网络的情况下,每间隔t个训练轮次,还需要对学生网络中主干网络的参数进行初始化,使其重置为通过步骤330训练得到的主干网络的参数(或者,重置为步骤310中主干网络的参数),从而得到一个重置的学生网络作为第t+1个训练轮次的初始学生网络。
步骤353,基于源域样本数据和目标域样本数据通过第t轮次更新得到的教师网络输出的伪标签,对第t轮次重置的学生网络进行训练,得到第二预测网络。
示意性的,对第t轮次重置的学生网络进行训练的方法可参考步骤351中得到第t轮次更新得到的学生网络的方法,此处不再赘述。需要进行说明的是,在进行训练时,参与训练的学生网络为第t轮次重置的学生网络,即进行了参数重新初始化(或者参数重置)的学生网络,参与训练的教师网络为第t轮次更新得到的教师网络,即未进行参数重新初始化(或者参数重置)的教师网络。
可选地,每间隔t个训练轮次,进行一次上述步骤352中所述的学生网络参数重置处理,在n×t个训练轮次后,将得到的教师网络作为第二预测网络,其中,n为正整数。
其中,第二预测网络用于在第一预测任务中对源域数据或者目标域数据进行预测。
综上所述,本申请实施例提供的预测模型的训练方法,获取通过源域数据训练的用于执行第一预测任务的第一预测网络后,对第一预测网络中的第一编码器进行跨域特征学习,即将变换后的目标域样本数据输入第一编码器,用原有的目标域样本数据作为标签对第一编码器进行监督训练,得到第二编码器;后续,根据训练好的第二编码器构建教师网络和学生网络,并基于源域样本数据和目标域样本数据,通过教师网络输出的伪标签对学生网络进行训练,得到在第一预测任务中具备源域数据和目标域数据预测能力的第二预测模型。其中,训练学生网络在第一预测任务中适应目标域数据之前,对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。
本申请实施例提供的方法,使用重训练机制允许学生网络跳出由噪声伪标签带来的局部最优,缓解伪标签中始终存在的噪声对教师-学生网络训练的影响。
本申请实施例提供的方法,通过指数滑动均值的方式使得教师网络中模型参数的更新更为平缓,从而使得教师网络产生的伪标签更为平缓,对比教师网络直接赋值参数的方式,不会因为某次的异常取值而使得教师网络中的模型参数波动很大,提高了伪标签产生的准确度。
在一些实施例中,以目标检测任务为例对预测模型的训练方法进行说明,目标域样本数据包括目标域样本图像。示意性的,如图4所示,上述图2或者图3中的实施例还可实现步骤410至步骤459。
步骤410,获取目标检测器。
其中,目标检测器是通过源域图像和源域标签在目标检测任务中训练得到的。
可选地,目标检测任务是指检测出图像中具有语义信息的特定类别目标(物体),并确定它们的类别和位置。
示意性的,获取源域图像和源域标签,通过源域图像和源域标签对候选检测器进行有监督训练,得到目标检测器,目标检测器用于对图像进行物体检测并将检测到的物体用检测框进行标注,同时在检测框上标注其框选的物体的类别(例如:行人、车辆等)。
可选地,目标检测器由主干网络、第一编码器和第一解码器三部分组成。
上述主干网络可实现为卷积神经网络(Visual Geometry Group,VGG)、残差网络(Residual Network,ResNet)等中的至少一种;上述第一编码器和第一解码器可实现为Transformer结构,例如:Deformable DETR(Deformable Detection Transformer)、DETR(Detection Transformer)、Conditional DETR等中的编码器和解码器结构。
步骤421,通过主干网络提取目标域样本图像对应的目标域特征表示。
即对目标域样本图像进行特征提取,得到目标域样本图像对应的图像特征表示。
其中,目标域样本图像中包括m个图像块,m为大于1的整数;图像特征表示中包括m个图像块分别对应的块特征表示。
示意性的,请参考图5,其示出了一种预测模型的训练流程图,如图5所示,通过主干网络1提取目标域图像对应的图像特征表示。
其中,主干网络1实现为二维视觉特征主干网络,用于对图像提取特征,可采用残差卷积神经网络结构。
步骤422,对目标域特征表示进行掩码处理,得到掩码特征表示。
即对图像特征表示进行掩码处理,得到掩码特征表示,该掩码特征表示即为掩码后的目标域样本数据,掩码特征表示中包括i个被屏蔽的块特征表示,i<n且i为正整数。
示意性的,如图5所示,选定m个图像块对应的块特征表示中部分块特征表示进行屏蔽,得到掩码特征表示501。
步骤423,通过第一编码器对掩码特征表示进行编码处理,得到掩码编码数据。
示意性的,如图5所示,将掩码特征表示501输入编码器1(即第一编码器)进行编码,得到掩码编码数据。可选地,编码器1可实现为可变形注意力机制编码器。其中,可变形注意力机制编码器利用可变形注意力机制聚合图像特征。
步骤424,通过掩码解码器对掩码编码数据进行解码处理,得到解码图像。
其中,解码图像用于表示对掩码特征表示进行重建得到的图像。
可选地,将掩码编码数据输入掩码解码器,对i个被屏蔽的块特征表示对应的图像块进行预测,得到i个预测图像块;基于i个预测图像块得到目标样本图像对应的重建图像,将重建图像作为解码数据。
示意性的,在对掩码编码数据进行解码时,可以仅关注掩码区域,即对i个被屏蔽的块特征表示进行分析,最终得到i个被屏蔽的块特征表示对应的图像块。
示意性的,如图5所示,将掩码编码数据输入辅助解码器进行解码处理,得到解码图像。其中,辅助解码器可实现为MAE解码器。
步骤430,基于解码图像和目标域样本图像之间的差异对第一编码器进行训练,得到第二编码器。
示意性的,如图5所示,基于解码图像和目标域样本图像之间的差异确定Lmask,通过Lmask对编码器1中的参数进行更新,得到编码器2(即第二编码器)。其中,Lmask对应的损失函数可实现为均方误差损失函数或平均绝对误差损失函数等,本申请实施例对此不进行限定。
步骤440,基于第二编码器更新第一预测网络,得到教师网络和学生网络。
示意性的,如图5所示,将通过掩码特征表示501训练得到的编码器2更新至教师网络和学生网络中,得到如图5所示的教师网络510和学生网络520。
步骤451,对目标域样本图像、源域样本图像进行数据强增强处理,得到强增强目标域样本图像和强增强源域样本图像。
可选地,对图像进行增强的方法主要包括几何增强和非几何增强;其中,几何增强会改变图像中像素位置(例如:旋转,翻转,平移等,往往造成图像标注的标签的改变),非几何增强不改变像素位置(例如:增强亮度,颜色,噪声等属性)。基于对图像的增强方法和数量,可将对图像的增强情况分为强增强和弱增强,强增强是指对图像进行几何增强和非几何增强;弱增强是指对图像仅进行几何增强。
示意性的,请参考图5,对目标域图像、源域图像进行数据强增强处理,得到强增强目标域图像和强增强源域图像。
步骤452,将强增强目标域样本图像和强增强源域样本图像输入学生网络中的主干网络,输出目标域样本特征表示和源域样本特征表示。
示意性的,请参考图5,将强增强目标域图像和强增强源域图像输入学生网络520中的主干网络1中,通过主干网络1提取源域图像对应的图像特征表示502和目标域图像对应的图像特征表示503。
步骤453,将目标域样本特征表示和源域样本特征表示输入学生网络中的第二编码器中,输出目标域编码数据和源域编码数据。
示意性的,请参考图5,将源域图像对应的图像特征表示502和目标域图像对应的图像特征表示503输入学生网络520中的编码器2中,输出目标域编码数据和源域编码数据。
步骤454,将目标域编码数据和源域编码数据输入学生网络中的第一解码器中,输出目标域解码数据和源域解码数据。
示意性的,请参考图5,将目标域编码数据和源域编码数据输入学生网络520中的解码器(即第一解码器)中,输出目标域解码数据和源域解码数据。其中解码器可实现为可变形注意力机制解码器。
示意性的,可变形注意力机制解码器利用可变形注意力机制提取候选区域特征,其中候选区域特征即解码得到的各个识别物体对应的区域特征。
步骤455,基于目标域解码数据确定目标域预测结果,基于源域解码数据确定源域预测结果。
示意性的,得到目标域和源域解码的候选区域特征后,进一步对候选区域进行物体类别分类和物体区域回归,得到检测结果(检测框位置和物体类别)。
步骤456,对目标域样本图像进行数据弱增强处理,得到弱增强目标域样本图像,并通过教师网络对弱增强目标域图像进行伪标签预测,得到伪标签。
示意性的,请参考图5,对目标域图像进行数据弱增强处理,得到弱增强目标域图像。将到弱增强目标域图像输入至教师网络510中,通过教师网络510中的主干网络1,编码器2和解码器对其进行分析,识别得到目标域图像中各个物体对应的检测框,将该多个检测框作为多个伪标签。
可选地,在输出多个伪标签时,还输出了伪标签对应的分类置信度,从而通过分类置信度对多个伪标签进行筛选。
示意性的,请参考图5,将教师网络510输出的多个伪标签输入过滤器中,使用置信度阈值,对输出的多个伪标签进行筛选,高于该置信度阈值的伪标签被认为是较为准确的检测框,即将该检测框作为伪标签。
在一些实施例中,上述使用的置信度阈值能够在训练过程中动态变化,且能够针对不同类别进行不同的变化。
可选地,通过教师网络对源域样本数据进行伪标签预测,得到多个源域标签以及多个源域标签分别对应的分类置信度,其中,多个分类置信度对应k个类别,k为正整数;计算第j个类别对应的多个分类置信度的平均置信度作为第j个类别对应的置信度阈值,j≤k且j为正整数;响应于多个候选伪标签中第j个类别对应的候选伪标签大于第j个类别对应的置信度阈值,将第j个类别对应的候选伪标签确定为伪标签。
示意性的,所有类别的置信度阈值都初始化为某一值δ0,以类别c为例,在每一轮训练结束后,统计源域图像在该类别检测框正样本的平均置信度然后更新类别c对应的置信度阈值,更新的公式如下公式三所示:
公式三:
其中,γ,a,b均为超参数,γ控制平均置信度的影响程度;b是0.5,提供一种凸函数,a是一种线性映射,二者共同作用,防止平均置信度的影响使置信度阈值过高或过低。置信度阈值在每一轮训练结束后更新。
可选地,本申请实施例中还硬性规定了所有置信度阈值的上界,当置信度阈值达到该上界时停止更新。
步骤457,基于目标域样本特征表示和源域样本特征表示之间的差异,确定第一对齐损失;基于目标域编码数据和源域编码数据之间的差异,确定第二对齐损失;基于目标域解码数据和源域解码数据之间的差异,确定第三对齐损失;对第一对齐损失、第二对齐损失和第三对齐损失进行加权融合,得到域特征对齐损失。
示意性的,如图5所示,在学生网络520中,主干网络1、编码器2和可变解码器均进行了特征对齐识别,分别产生了对齐损失其加和为/>最终产生域特征对齐损失/>
步骤458,基于源域预测结果和源域标签之间的差异,确定源域任务损失;基于目标域预测结果和伪标签之间的差异,确定目标域任务损失。
示意性的,将输入数据记为源域图像和目标域图像/> 其中Ns是源域图像集样本数量,Nt是目标域图像集样本数量,x是图像,y=(b,c)是目标检测标签,包含了检测框b和对应物体类别c。
本实施例采用Deformable DETR作为基础检测器,源域图像由其本身的标签做监督,使用基础检测器本身的损失函数作为有监督损失的计算公式如下公式四所示:
公式四:
其中,为检测框损失,/>为GIOU损失和/>为分类损失。
示意性的,检测框损失用于衡量预测框与真实框之间的差异,通过比较两个框之间的位置和形状来计算这个损失;首先,计算预测框和真实框的中心点坐标的差值。然后,计算预测框和真实框的宽度和高度的差值。最后,将坐标差值和尺寸差值结合起来,计算出检测框损失。
GIOU损失考虑了框的形状和位置信息,首先,计算预测框和真实框的相交面积和相并面积。然后,计算出相交区域的最小闭合矩形的面积。接着,通过将相交区域的最小闭合矩形的面积除以相并区域的面积来得到GIOU损失。
分类损失用于衡量预测框的类别和真实框的类别之间的差异。使用交叉熵损失函数计算预测框的类别和真实框的类别之间的差异。
目标域图像由教师网络产出的伪标签做监督,仅使用分类损失作为无监督损失其中,/>表示伪标签,包含了检测框/>和对应物体类别
步骤459,对源域任务损失、目标域任务损失和域特征对齐损失进行加权融合,得到融合损失;基于融合损失对学生网络进行训练,得到第二预测网络。
示意性的,如图5所示,对源域任务损失目标域任务损失/>和域特征对齐损失/>进行加权融合,通过融合后的损失对学生网络的参数进行更新。在对对学生网络520的参数进行更新后,通过EMA方法对教师网络510的参数进行更新。
可选地,本申请实施例中,还设计了重训练机制,即每间隔若干训练轮次,重新初始化如图5所示的虚线模块(即学生网络520中的主干网络1和编码器2)参数为通过掩码特征表示501训练得到的解码器的参数,同时保持如图5所示的教师网络510和学生网络520中实线模型的参数不变,重新训练进行训练,使得学生网络520能够跳出由噪声伪标签带来的局部最优。
最终,得到训练好的教师网络作为第二预测模型。
综上所述,本申请实施例提供了一种预测模型的训练方法,通过构建教师-学生网络框架,利用教师网络为学生网络提供伪标签,训练学生网络适应目标域特征;针对训练早期伪标签检测框数量少的问题,设计了掩码自编码器分支,将目标域图像特征遮蔽后送入目标检测器的编码器,使用辅助解码器进行特征重建,用原有特征进行自监督,使模型能够在伪标签检测框数量不足的情况下适应目标域特征;针对伪标签噪声问题,设计了重训练机制,将学生网络间隔一定训练轮次重新初始化,采用持续更新的教师网络进行重训练,以此允许模型跳出由伪标签噪声造成的局部最优。本申请实施例提供的方法可以在无需额外注的前提下训练目标检测模型,适应新的目标域图像特征,维持其性能稳定。
本申请实施例提供的方法,对用于进行伪标签筛选的置信度阈值进行动态更新,其中,置信度阈值与源图像的预测置信度强相关。原因是置信度在训练过程中持续增加,对于这些分数若固定阈值,选择的伪标签对应的检测框的数量会不受限地增加,并引入大量错误。本申请实施例的筛选方法通过置信度的动态更新有助于减少这种错误累积。
本申请实施例提供的预测模型的训练方法训练得到的模型可以应用于道路交通的动态障碍物检测,其优势在于,在源域数据上训练后,适应到新的目标域时,无需额外的目标域数据标注。
示意性的,检测器在晴天环境下训练,即使用如图6所示的晴天的带标注的图像600作为源域数据。其中图像600中是指在晴天环境下拍摄的图像,检测框601用于框选其中的物体,并标注类别“车”(需要进行说明的是,图6中仅示意性的展示的检测框601的类别标注信息,其他检测框的类别标注信息未示出)。
若不使用本申请实施例提供的预测模型的训练方法训练得到的模型方法,其在雾天的检测效果会明显下降,即如图7所示的针对雾天环境(即目标域)的图像的检测效果图像700,可检测出的物体较少,存在漏检情况。
若使用本申请实施例提供的预测模型的训练方法训练得到的模型方法,则检测效果明显提升,即如图8所示的针对雾天环境的图像的检测效果图像800,可检测出的物体明显较多,接近真实结果(即如图9所示的检测效果图像900)。
下面的实验表明,与相关技术中的目标检测方法相比,本申请实施例提供的预测模型的训练方法训练得到的模型可以取得更高的检测准确率。本实验采用了Cityscapes→Foggy Cityscapes、Sim10k→Cityscapes和Cityscapes→BDD100k三个公开评测集。
其中Cityscapes是从城市场景中收集的,包含2975张用于训练和500张用于验证的图像。Foggy Cityscapes通过从Cityscapes中的雾合成算法构建而来。本实验将Cityscapes作为源域,将具有最高雾密度(0.02)的Foggy Cityscapes作为目标域。BDD100k是一个大规模驾驶数据集。其白天子集包含36728张训练图像和5258张验证图像。本实验使用Cityscapes作为源域,将BDD100k白天子集作为目标域。Sim10k是一个来自GTA游戏引擎的合成数据集,包含10000张图像。本实验将Sim10k用作源域,将Cityscapes中的“汽车”实例作为目标域。
实验采用特定交并比(Intersection over Union,IoU)下的精确率来评测目标检测的准确性。该指标衡量的是预测框和正确结果之间IoU大于或等于阈值0.5的百分比(AP@50)。
如表1所示,其中展示了本申请实施例提供的预测模型的训练方法训练得到的模型与相关技术中的目标检测方法的对比实验结果。
表1
方法 Cityscapes→Foggy Sim10k→Cityscapes Cityscapes→BDD100k
方法一 41.3 52.6 28.9
方法二 47.1 53.4 29.4
方法三 43.4 57.9 32.6
方法四 40.7 49.8 29.3
本申请 51.2 62.0 33.5
从表1可以看出,本申请实施例提供的预测模型的训练方法在性能上超越了方法至方法四,取得了显著的改进。
请参考图10,其示出了本申请一个示例性的实施例提供的预测模型的训练装置的结构框图,该装置包括:
获取模块1010,用于获取第一预测网络,所述第一预测网络是通过源域数据在第一预测任务中训练得到的网络,所述第一预测网络中包括第一编码器;
变换模块1020,用于对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,得到编码数据;
训练模块1030,用于基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器;
更新模块1040,用于基于所述第二编码器更新所述第一预测网络,得到教师网络和学生网络;
所述训练模块1030,还用于基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,所述第二预测网络用于在所述第一预测任务中对源域数据或者目标域数据进行预测。
在一些实施例中,请参考图11,所述变换模块1020,包括:
掩码单元1021,用于对所述目标域样本数据进行掩码处理,得到掩码后的目标域样本数据作为所述目标域变换数据;
编码单元1022,用于将所述掩码后的目标域样本数据输入所述第一编码器进行编码,得到掩码编码数据。
所述训练模块1030,还用于:
将所述掩码编码数据输入掩码解码器进行解码,得到解码数据,所述解码数据用于表征对所述掩码后的目标域样本数据进行数据重建得到的数据;
基于所述解码数据和所述目标域样本数据之间的差异对所述第一编码器进行训练,得到所述第二编码器。
在一些实施例中,所述目标域样本数据包括目标域样本图像,所述目标域样本图像中包括m个图像块,m为大于1的整数;所述掩码单元1021,用于:
对所述目标域样本图像进行特征提取,得到所述目标域样本图像对应的图像特征表示,所述图像特征表示中包括所述m个图像块分别对应的块特征表示;
对所述图像特征表示进行掩码处理,得到掩码特征表示作为所述掩码后的目标域样本数据,所述掩码特征表示中包括i个被屏蔽的块特征表示,i<m且i为正整数。
所述编码单元1022,用于:
将所述掩码编码数据输入所述掩码解码器,对所述i个被屏蔽的块特征表示对应的图像块进行预测,得到i个预测图像块;
基于所述i个预测图像块得到所述目标样本图像对应的重建图像,将所述重建图像作为解码数据。
在一些实施例中,所述变换模块1020,包括:
增强单元1023,用于对所述目标域样本数据进行增强处理,得到增强后的目标域样本数据作为所述目标域变换数据,所述增强后的目标域样本数据包括所述目标域样本数据对应的正样本数据和负样本数据;
所述编码单元1022,用于将所述目标域样本数据、所述正样本数据和所述负样本数据输入所述第一编码器进行编码,得到所述目标域样本数据对应的目标域编码数据、所述正样本数据对应的正样本编码数据和所述负样本数据对应的负样本编码数据。
所述训练模块1030,还用于:
以最小化所述目标域编码数据和所述正样本编码数据之间的差异,同时最大所述目标域编码数据和所述负样本编码数据之间的差异为目标,对所述第一编码器进行训练,得到所述第二编码器。
在一些实施例中,所述训练模块1030,还用于:
在第t轮次迭代更新中,基于所述源域样本数据和所述目标域样本数据通过第t-1轮次更新得到的教师网络输出的伪标签,对第t-1轮次更新得到的学生网络进行更新,得到第t轮次更新得到的学生网络和第t轮次更新得到的教师网络;
将所述第t轮次更新得到的学生网络中编码器的参数初始化为所述第二编码器的参数,得到所述第t轮次重置的学生网络;
基于所述源域样本数据和所述目标域样本数据通过所述第t轮次更新得到的教师网络输出的伪标签,对所述第t轮次重置的学生网络进行训练,得到所述第二预测网络。
在一些实施例中,所述训练模块1030,还用于:
在得到所述第t轮次更新得到的学生网络后,获取所述第t轮次更新得到的学生网络的第一模型参数以及所述第t-1轮次更新得到的教师网络的第二模型参数;
根据预设更新参数对所述第一模型参数和所述第二模型参数进行加权融合,得到融合模型参数;
基于所述融合模型参数对所述第t-1轮次更新得到的教师网络进行更新,得到所述第t轮次更新得到的教师网络。
在一些实施例中,所述训练模块1030,还用于:
对所述源域样本数据、所述目标域样本数据分别进行数据强增强处理,得到强增强源域样本数据、强增强目标域样本数据;
通过所述学生网络对所述强增强源域样本数据、所述强增强目标域样本数据分别进行任务预测,得到源域预测结果和目标域预测结果,所述目标域预测结果用于指示所述目标域样本数据在所述第一预测任务中的预测结果,所述源域预测结果用于指示所述源域样本数据在所述第一预测任务中的预测结果;
获取源域标签,所述源域标签用于指示所述源域样本数据在所述第一预测任务中的参考结果;
对所述目标域样本数据进行数据弱增强处理,得到弱增强目标域样本数据;通过所述教师网络对所述弱增强目标域样本数据进行伪标签预测,得到伪标签;其中,强增强处理对数据的调整程度大于弱增强处理对数据的调整程度;
基于所述源域预测结果和所述源域标签之间的差异,确定源域任务损失;基于所述目标域预测结果和所述伪标签之间的差异,确定目标域任务损失;
基于所述源域任务损失和所述目标域任务损失对学生网络进行训练,得到所述第二预测网络。
在一些实施例中,所述第一预测任务包括目标检测任务;所述训练模块1030,还用于:
通过所述教师网络对所述弱增强目标域样本数据进行伪标签预测,得到多个候选伪标签以及多个候选伪标签分别对应的分类置信度,所述多个候选伪标签用于指示对所述弱增强目标域样本数据进行物体识别后得到的检测框,所述分类置信度用于指示所述检测框框选的物体类别的置信度;
将所述多个候选伪标签中分类置信度大于置信度阈值的候选伪标签确定为所述伪标签。
在一些实施例中,所述训练模块1030,还用于:
通过所述教师网络对所述源域样本数据进行伪标签预测,得到多个源域标签以及多个源域标签分别对应的分类置信度,其中,多个分类置信度对应k个类别,k为正整数;
计算第j个类别对应的多个分类置信度的平均置信度作为所述第j个类别对应的置信度阈值,j≤k且j为正整数;
所述将所述多个候选伪标签中分类置信度大于置信度阈值的候选伪标签确定为所述伪标签,包括:
响应于所述多个候选伪标签中所述第j个类别对应的候选伪标签大于所述第j个类别对应的置信度阈值,将所述第j个类别对应的候选伪标签确定为所述伪标签。
综上所述,本申请实施例提供的预测模型的训练装置,获取通过源域数据训练的用于执行第一预测任务的第一预测网络后,对第一预测网络中的第一编码器进行跨域特征学习,即将变换后的目标域样本数据输入第一编码器,用原有的目标域样本数据作为标签对第一编码器进行监督训练,得到第二编码器;后续,根据训练好的第二编码器构建教师网络和学生网络,并基于源域样本数据和目标域样本数据,通过教师网络输出的伪标签对学生网络进行训练,得到在第一预测任务中具备源域数据和目标域数据预测能力的第二预测模型。其中,训练学生网络在第一预测任务中适应目标域数据之前,对教师网络和学生网络的共享编码器进行目标域特征学习,使得教师网络和学生网络能够提前适应目标域特征,从而提高产生的伪标签质量,提升对学生网络的训练效果。
需要说明的是:上述实施例提供的预测模型的训练装置,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的预测模型的训练装置和预测模型的训练方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
图12示出了本申请一个示例性实施例提供的计算机设备的结构示意图。具体来讲包括以下结构:
计算机设备1200包括中央处理单元(Central Processing Unit,CPU)1201、包括随机存取存储器(Random Access Memory,RAM)1202和只读存储器(Read Only Memory,ROM)1203的***存储器1204,以及连接***存储器1204和中央处理单元1201的***总线1205。计算机设备1200还包括用于存储操作***1213、应用程序1214和其他程序模块1215的大容量存储设备1206。
大容量存储设备1206通过连接到***总线1205的大容量存储控制器(未示出)连接到中央处理单元1201。大容量存储设备1206及其相关联的计算机可读介质为计算机设备1200提供非易失性存储。也就是说,大容量存储设备1206可以包括诸如硬盘或者紧凑型光盘只读存储器(Compact Disc Read Only Memory,CD-ROM)驱动器之类的计算机可读介质(未示出)。
不失一般性,计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、可擦除可编程只读存储器(Erasable Programmable Read Only Memory,EPROM)、带电可擦可编程只读存储器(Electrically Erasable Programmable Read Only Memory,EEPROM)、闪存或其他固态存储技术,CD-ROM、数字通用光盘(Digital Versatile Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知计算机存储介质不局限于上述几种。上述的***存储器1204和大容量存储设备1206可以统称为存储器。
根据本申请的各种实施例,计算机设备1200还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即计算机设备1200可以通过连接在***总线1205上的网络接口单元1211连接到网络1212,或者说,也可以使用网络接口单元1211来连接到其他类型的网络或远程计算机***(未示出)。
上述存储器还包括一个或者一个以上的程序,一个或者一个以上程序存储于存储器中,被配置由CPU执行。
本申请的实施例还提供了一种计算机可读存储介质,该计算机可读存储介质上存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行,以实现上述各方法实施例提供的预测模型的训练方法。
本申请的实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述各方法实施例提供的预测模型的训练方法。
可选地,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、固态硬盘(SSD,Solid State Drives)或光盘等。其中,随机存取记忆体可以包括电阻式随机存取记忆体(ReRAM,Resistance RandomAccess Memory)和动态随机存取存储器(DRAM,Dynamic Random Access Memory)。上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (14)

1.一种预测模型的训练方法,其特征在于,所述方法包括:
获取第一预测网络,所述第一预测网络是通过源域数据在第一预测任务中训练得到的网络,所述第一预测网络中包括第一编码器;
对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,得到编码数据;
基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器;
基于所述第二编码器更新所述第一预测网络,得到教师网络和学生网络;
基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,所述第二预测网络用于在所述第一预测任务中对源域数据或者目标域数据进行预测。
2.根据权利要求1所述的方法,其特征在于,所述对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,包括:
对所述目标域样本数据进行掩码处理,得到掩码后的目标域样本数据作为所述目标域变换数据;
将所述掩码后的目标域样本数据输入所述第一编码器进行编码,得到掩码编码数据;
所述基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器,包括:
将所述掩码编码数据输入掩码解码器进行解码,得到解码数据,所述解码数据用于表征对所述掩码后的目标域样本数据进行数据重建得到的数据;
基于所述解码数据和所述目标域样本数据之间的差异对所述第一编码器进行训练,得到所述第二编码器。
3.根据权利要求2所述的方法,其特征在于,所述目标域样本数据包括目标域样本图像,所述目标域样本图像中包括m个图像块,m为大于1的整数;
所述对所述目标域样本数据进行掩码处理,得到掩码后的目标域样本数据作为所述目标域变换数据,包括:
对所述目标域样本图像进行特征提取,得到所述目标域样本图像对应的图像特征表示,所述图像特征表示中包括所述m个图像块分别对应的块特征表示;
对所述图像特征表示进行掩码处理,得到掩码特征表示作为所述掩码后的目标域样本数据,所述掩码特征表示中包括i个被屏蔽的块特征表示,i<m且i为正整数;
所述将所述掩码编码数据输入掩码解码器进行解码,得到解码数据,包括:
将所述掩码编码数据输入所述掩码解码器,对所述i个被屏蔽的块特征表示对应的图像块进行预测,得到i个预测图像块;
基于所述i个预测图像块得到所述目标样本图像对应的重建图像,将所述重建图像作为解码数据。
4.根据权利要求1所述的方法,其特征在于,所述对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,包括:
对所述目标域样本数据进行增强处理,得到增强后的目标域样本数据作为所述目标域变换数据,所述增强后的目标域样本数据包括所述目标域样本数据对应的正样本数据和负样本数据;
将所述目标域样本数据、所述正样本数据和所述负样本数据输入所述第一编码器进行编码,得到所述目标域样本数据对应的目标域编码数据、所述正样本数据对应的正样本编码数据和所述负样本数据对应的负样本编码数据;
所述基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器,包括:
以最小化所述目标域编码数据和所述正样本编码数据之间的差异,同时最大所述目标域编码数据和所述负样本编码数据之间的差异为目标,对所述第一编码器进行训练,得到所述第二编码器。
5.根据权利要求1至4任一所述的方法,其特征在于,所述基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,包括:
在第t轮次迭代更新中,基于所述源域样本数据和所述目标域样本数据通过第t-1轮次更新得到的教师网络输出的伪标签,对第t-1轮次更新得到的学生网络进行更新,得到第t轮次更新得到的学生网络和第t轮次更新得到的教师网络;
将所述第t轮次更新得到的学生网络中编码器的参数初始化为所述第二编码器的参数,得到所述第t轮次重置的学生网络;
基于所述源域样本数据和所述目标域样本数据通过所述第t轮次更新得到的教师网络输出的伪标签,对所述第t轮次重置的学生网络进行训练,得到所述第二预测网络。
6.根据权利要求5所述的方法,其特征在于,所述得到第t轮次更新得到的学生网络和第t轮次更新得到的教师网络,包括:
在得到所述第t轮次更新得到的学生网络后,获取所述第t轮次更新得到的学生网络的第一模型参数以及所述第t-1轮次更新得到的教师网络的第二模型参数;
根据预设更新参数对所述第一模型参数和所述第二模型参数进行加权融合,得到融合模型参数;
基于所述融合模型参数对所述第t-1轮次更新得到的教师网络进行更新,得到所述第t轮次更新得到的教师网络。
7.根据权利要求1至4任一所述的方法,其特征在于,所述基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,包括:
对所述源域样本数据、所述目标域样本数据分别进行数据强增强处理,得到强增强源域样本数据、强增强目标域样本数据;
通过所述学生网络对所述强增强源域样本数据、所述强增强目标域样本数据分别进行任务预测,得到源域预测结果和目标域预测结果,所述目标域预测结果用于指示所述目标域样本数据在所述第一预测任务中的预测结果,所述源域预测结果用于指示所述源域样本数据在所述第一预测任务中的预测结果;
获取源域标签,所述源域标签用于指示所述源域样本数据在所述第一预测任务中的参考结果;
对所述目标域样本数据进行数据弱增强处理,得到弱增强目标域样本数据;通过所述教师网络对所述弱增强目标域样本数据进行伪标签预测,得到伪标签;其中,强增强处理对数据的调整程度大于弱增强处理对数据的调整程度;
基于所述源域预测结果和所述源域标签之间的差异,确定源域任务损失;基于所述目标域预测结果和所述伪标签之间的差异,确定目标域任务损失;
基于所述源域任务损失和所述目标域任务损失对学生网络进行训练,得到所述第二预测网络。
8.根据权利要求7所述的方法,其特征在于,所述方法还包括:
通过所述学生网络对所述强增强源域样本数据、所述强增强目标域样本数据进行任务预测时,对所述强增强源域样本数据和所述强增强目标域样本数据进行特征对齐识别,得到域特征对齐损失;
所述基于所述源域任务损失和所述目标域任务损失对学生网络进行训练,得到所述第二预测网络,包括:
对所述源域任务损失、所述目标域任务损失和所述域特征对齐损失进行加权融合,得到融合损失;
基于所述融合损失对所述学生网络进行训练,得到所述第二预测网络。
9.根据权利要求8所述的方法,其特征在于,所述第一预测任务包括目标检测任务;
所述通过所述教师网络对所述弱增强目标域样本数据进行伪标签预测,得到伪标签,包括:
通过所述教师网络对所述弱增强目标域样本数据进行伪标签预测,得到多个候选伪标签以及多个候选伪标签分别对应的分类置信度,所述多个候选伪标签用于指示对所述弱增强目标域样本数据进行物体识别后得到的检测框,所述分类置信度用于指示所述检测框框选的物体类别的置信度;
将所述多个候选伪标签中分类置信度大于置信度阈值的候选伪标签确定为所述伪标签。
10.根据权利要求9所述的方法,其特征在于,所述方法还包括:
通过所述教师网络对所述源域样本数据进行伪标签预测,得到多个源域标签以及多个源域标签分别对应的分类置信度,其中,多个分类置信度对应k个类别,k为正整数;
计算第j个类别对应的多个分类置信度的平均置信度作为所述第j个类别对应的置信度阈值,j≤k且j为正整数;
所述将所述多个候选伪标签中分类置信度大于置信度阈值的候选伪标签确定为所述伪标签,包括:
响应于所述多个候选伪标签中所述第j个类别对应的候选伪标签大于所述第j个类别对应的置信度阈值,将所述第j个类别对应的候选伪标签确定为所述伪标签。
11.一种预测模型的训练装置,其特征在于,所述装置包括:
获取模块,用于获取第一预测网络,所述第一预测网络是通过源域数据在第一预测任务中训练得到的网络,所述第一预测网络中包括第一编码器;
变换模块,用于对目标域样本数据进行变换处理,得到目标域变换数据,并通过所述第一编码器对所述目标域变换数据进行编码处理,得到编码数据;
训练模块,用于基于所述编码数据和所述目标域样本数据对所述第一编码器进行训练,得到第二编码器;
更新模块,用于基于所述第二编码器更新所述第一预测网络,得到教师网络和学生网络;
所述训练模块,还用于基于源域样本数据和所述目标域样本数据通过所述教师网络输出的伪标签,对所述学生网络进行训练,得到第二预测网络,所述第二预测网络用于在所述第一预测任务中对源域数据或者目标域数据进行预测。
12.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一段计算机程序,所述至少一段计算机程序由所述处理器加载并执行以实现如权利要求1至10任一所述的预测模型的训练方法。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有至少一段计算机程序,所述至少一段计算机程序由处理器加载并执行以实现如权利要求1至10任一所述的预测模型的训练方法。
14.一种计算机程序产品,其特征在于,包括计算机程序,所述计算机程序被处理器执行时实现如权利要求1至10任一所述的预测模型的训练方法。
CN202311137508.1A 2023-09-04 2023-09-04 预测模型的训练方法、装置、设备、介质及程序产品 Pending CN117217368A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311137508.1A CN117217368A (zh) 2023-09-04 2023-09-04 预测模型的训练方法、装置、设备、介质及程序产品

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311137508.1A CN117217368A (zh) 2023-09-04 2023-09-04 预测模型的训练方法、装置、设备、介质及程序产品

Publications (1)

Publication Number Publication Date
CN117217368A true CN117217368A (zh) 2023-12-12

Family

ID=89041774

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311137508.1A Pending CN117217368A (zh) 2023-09-04 2023-09-04 预测模型的训练方法、装置、设备、介质及程序产品

Country Status (1)

Country Link
CN (1) CN117217368A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117476036A (zh) * 2023-12-27 2024-01-30 广州声博士声学技术有限公司 一种环境噪声识别方法、***、设备和介质
CN117876822A (zh) * 2024-03-11 2024-04-12 盛视科技股份有限公司 应用于鱼眼场景中的目标检测迁移训练方法

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117476036A (zh) * 2023-12-27 2024-01-30 广州声博士声学技术有限公司 一种环境噪声识别方法、***、设备和介质
CN117476036B (zh) * 2023-12-27 2024-04-09 广州声博士声学技术有限公司 一种环境噪声识别方法、***、设备和介质
CN117876822A (zh) * 2024-03-11 2024-04-12 盛视科技股份有限公司 应用于鱼眼场景中的目标检测迁移训练方法
CN117876822B (zh) * 2024-03-11 2024-05-28 盛视科技股份有限公司 应用于鱼眼场景中的目标检测迁移训练方法

Similar Documents

Publication Publication Date Title
WO2023077816A1 (zh) 边界优化的遥感图像语义分割方法、装置、设备及介质
CN113780296B (zh) 基于多尺度信息融合的遥感图像语义分割方法及***
CN117217368A (zh) 预测模型的训练方法、装置、设备、介质及程序产品
CN113780149A (zh) 一种基于注意力机制的遥感图像建筑物目标高效提取方法
CN114332578A (zh) 图像异常检测模型训练方法、图像异常检测方法和装置
CN113033436B (zh) 障碍物识别模型训练方法及装置、电子设备、存储介质
CN117079163A (zh) 一种基于改进yolox-s的航拍图像小目标检测方法
CN114067162A (zh) 一种基于多尺度多粒度特征解耦的图像重构方法及***
CN110991374B (zh) 一种基于rcnn的指纹奇异点检测方法
CN113920379B (zh) 一种基于知识辅助的零样本图像分类方法
CN113989574B (zh) 图像解释方法、图像解释装置、电子设备和存储介质
CN114821299A (zh) 一种遥感图像变化检测方法
US11954917B2 (en) Method of segmenting abnormal robust for complex autonomous driving scenes and system thereof
CN113096070A (zh) 一种基于MA-Unet的图像分割方法
CN113747168A (zh) 多媒体数据描述模型的训练方法和描述信息的生成方法
CN116975347A (zh) 图像生成模型训练方法及相关装置
Yang et al. How to use extra training data for better edge detection?
CN115311598A (zh) 基于关系感知的视频描述生成***
Wang Remote sensing image semantic segmentation network based on ENet
Gou et al. A Semantic Consistency Feature Alignment Object Detection Model Based on Mixed-Class Distribution Metrics
Anilkumar et al. An adaptive multichannel DeepLabv3+ for semantic segmentation of aerial images using improved Beluga whale optimization algorithm
Li et al. A fast detection method for polynomial fitting lane with self-attention module added
CN117808802B (zh) 一种基于多提示引导的通用细粒度视觉计数方法及***
CN117408891B (zh) 一种基于Cycle-GAN的图像加雾方法
CN117893934B (zh) 一种改进的UNet3+网络无人机影像铁路轨道线检测方法与装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication