CN115100461A - 图像分类模型训练方法、装置、电子设备、介质 - Google Patents

图像分类模型训练方法、装置、电子设备、介质 Download PDF

Info

Publication number
CN115100461A
CN115100461A CN202210664512.2A CN202210664512A CN115100461A CN 115100461 A CN115100461 A CN 115100461A CN 202210664512 A CN202210664512 A CN 202210664512A CN 115100461 A CN115100461 A CN 115100461A
Authority
CN
China
Prior art keywords
model
image
update data
trained
pruning
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210664512.2A
Other languages
English (en)
Other versions
CN115100461B (zh
Inventor
刘吉
高志强
章红
周景博
窦德景
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202210664512.2A priority Critical patent/CN115100461B/zh
Publication of CN115100461A publication Critical patent/CN115100461A/zh
Application granted granted Critical
Publication of CN115100461B publication Critical patent/CN115100461B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Molecular Biology (AREA)
  • Data Mining & Analysis (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供了一种图像分类模型训练方法、装置、电子设备和介质,涉及数据处理与人工智能领域,尤其涉及图像处理、联邦学习与计算机视觉。一种图像分类模型训练方法包括:接收K个由设备训练的模型的K个模型参数集,由设备训练的模型分别通过对应设备上的样本图像集对相同的待训练模型进行训练而获得,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个;基于分别从K个设备接收的K个设备图像分布指标和K个模型参数集,为第一设备训练的模型确定第一设备模型更新数据,每个设备图像分布指标包括预定图像类别集合的每个类别在对应设备的样本图像集中的出现频率;以及向第一设备发送第一设备模型更新数据。

Description

图像分类模型训练方法、装置、电子设备、介质
技术领域
本公开涉及数据处理与人工智能技术领域,尤其涉及图像处理、联邦学习与计算机视觉,具体涉及一种图像分类模型训练方法、装置、电子设备、计算机可读存储介质和计算机程序产品。
背景技术
人工智能是研究使计算机来模拟人的某些思维过程和智能行为(如学习、推理、思考、规划等)的学科,既有硬件层面的技术也有软件层面的技术。人工智能硬件技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理等技术:人工智能软件技术主要包括计算机视觉技术、语音识别技术、自然语言处理技术以及机器学习/深度学习、大数据处理技术、知识图谱技术等几大方向。
诸如智能手机、平板、智能手表等的设备收集了大量的数据,包括视频和图像等。此外,随着人工智能技术的快速发展,深度学习技术作为其中最为重要的技术之一,往往需要庞大的数据作为基础,而这些智能设备上的数据无疑是十分具有吸引力的。尤其是,在图像识别和图像分类层面,用户设备上的数据将是非常有用的学习资源。然而,将设备上的数据收集起来再集中进行训练可能会带来巨大的传输资源消耗,并且也可能存在隐私泄漏的危险。
在此部分中描述的方法不一定是之前已经设想到或采用的方法。除非另有指明,否则不应假定此部分中描述的任何方法仅因其包括在此部分中就被认为是现有技术。类似地,除非另有指明,否则此部分中提及的问题不应认为在任何现有技术中已被公认。
发明内容
本公开提供了一种图像分类模型训练方法、装置、电子设备、计算机可读存储介质和计算机程序产品。
根据本公开的一方面,提供了一种图像分类模型训练方法,包括:接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数;基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率;以及向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
根据本公开的另一方面,提供了一种图像分类模型训练方法,包括:向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个;从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率;以及基于所述第一设备模型更新数据对所述第一模型进行更新。
根据本公开的另一方面,提供了一种图像分类模型训练装置,包括:模型参数接收单元,用于接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数;更新数据确定单元,用于基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率;以及更新数据发送单元,用于向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
根据本公开的另一方面,提供了一种图像分类模型训练装置,包括:发送单元,用于向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个;接收单元,用于从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率;以及更新单元,用于基于所述第一设备模型更新数据对所述第一模型进行更新。
根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行根据本公开的一个或多个实施例的图像分类模型训练方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据本公开的一个或多个实施例的图像分类模型训练方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,其中,所述计算机程序在被处理器执行时实现根据本公开的一个或多个实施例的图像分类模型训练方法。
根据本公开的一个或多个实施例,可以在减少数据传输量和保证数据隐私的情况下实现针对性的模型调节,从而节省数据体量、减少传输所需的资源并且增加计算效率。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图示例性地示出了实施例并且构成说明书的一部分,与说明书的文字描述一起用于讲解实施例的示例性实施方式。所示出的实施例仅出于例示的目的,并不限制权利要求的范围。在所有附图中,相同的附图标记指代类似但不一定相同的要素。
图1示出了根据本公开的实施例的可以在其中实施本文描述的各种方法的示例性***的示意图;
图2示出了根据本公开的实施例的图像分类模型训练方法的流程图;
图3示出了根据本公开的实施例的数据流示意图;
图4示出了根据本公开的实施例的图像分类模型训练方法的流程图;
图5示出了根据本公开的实施例的图像分类模型训练装置的结构框图;
图6示出了根据本公开的实施例的图像分类模型训练装置的结构框图;
图7示出了能够用于实现本公开的实施例的示例性电子设备的结构框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
在本公开中,除非另有说明,否则使用术语“第一”、“第二”等来描述各种要素不意图限定这些要素的位置关系、时序关系或重要性关系,这种术语只是用于将一个元件与另一元件区分开。在一些示例中,第一要素和第二要素可以指向该要素的同一实例,而在某些情况下,基于上下文的描述,它们也可以指代不同实例。
在本公开中对各种所述示例的描述中所使用的术语只是为了描述特定示例的目的,而并非旨在进行限制。除非上下文另外明确地表明,如果不特意限定要素的数量,则该要素可以是一个也可以是多个。此外,本公开中所使用的术语“和/或”涵盖所列出的项目中的任何一个以及全部可能的组合方式。
下面将结合附图详细描述本公开的实施例。
图1示出了根据本公开的实施例可以将本文描述的各种方法和装置在其中实施的示例性***100的示意图。参考图1,该***100包括一个或多个客户端设备101、102、103、104、105和106、服务器120以及将一个或多个客户端设备耦接到服务器120的一个或多个通信网络110。客户端设备101、102、103、104、105和106可以被配置为执行一个或多个应用程序。
在本公开的实施例中,服务器120可以运行使得能够执行根据本公开的图像分类模型训练方法的一个或多个服务或软件应用。
在某些实施例中,服务器120还可以提供其他服务或软件应用,这些服务或软件应用可以包括非虚拟环境和虚拟环境。在某些实施例中,这些服务可以作为基于web的服务或云服务提供,例如在软件即服务(SaaS)模型下提供给客户端设备101、102、103、104、105和/或106的用户。
在图1所示的配置中,服务器120可以包括实现由服务器120执行的功能的一个或多个组件。这些组件可以包括可由一个或多个处理器执行的软件组件、硬件组件或其组合。操作客户端设备101、102、103、104、105和/或106的用户可以依次利用一个或多个客户端应用程序来与服务器120进行交互以利用这些组件提供的服务。应当理解,各种不同的***配置是可能的,其可以与***100不同。因此,图1是用于实施本文所描述的各种方法的***的一个示例,并且不旨在进行限制。
用户可以使用客户端设备101、102、103、104、105和/或106来进行图像分类、训练图像分类模型、查看训练结果或分类结果等。客户端设备可以提供使客户端设备的用户能够与客户端设备进行交互的接口。客户端设备还可以经由该接口向用户输出信息。尽管图1仅描绘了六种客户端设备,但是本领域技术人员将能够理解,本公开可以支持任何数量的客户端设备。
客户端设备101、102、103、104、105和/或106可以包括各种类型的计算机设备,例如便携式手持设备、通用计算机(诸如个人计算机和膝上型计算机)、工作站计算机、可穿戴设备、智能屏设备、自助服务终端设备、服务机器人、游戏***、瘦客户端、各种消息收发设备、传感器或其他感测设备等。这些计算机设备可以运行各种类型和版本的软件应用程序和操作***,例如MICROSOFT Windows、APPLE iOS、类UNIX操作***、Linux或类Linux操作***(例如GOOGLE Chrome OS);或包括各种移动操作***,例如MICROSOFT WindowsMobile OS、iOS、Windows Phone、Android。便携式手持设备可以包括蜂窝电话、智能电话、平板电脑、个人数字助理(PDA)等。可穿戴设备可以包括头戴式显示器(诸如智能眼镜)和其他设备。游戏***可以包括各种手持式游戏设备、支持互联网的游戏设备等。客户端设备能够执行各种不同的应用程序,例如各种与Internet相关的应用程序、通信应用程序(例如电子邮件应用程序)、短消息服务(SMS)应用程序,并且可以使用各种通信协议。
网络110可以是本领域技术人员熟知的任何类型的网络,其可以使用多种可用协议中的任何一种(包括但不限于TCP/IP、SNA、IPX等)来支持数据通信。仅作为示例,一个或多个网络110可以是局域网(LAN)、基于以太网的网络、令牌环、广域网(WAN)、因特网、虚拟网络、虚拟专用网络(VPN)、内部网、外部网、区块链网络、公共交换电话网(PSTN)、红外网络、无线网络(例如蓝牙、WIFI)和/或这些和/或其他网络的任意组合。
服务器120可以包括一个或多个通用计算机、专用服务器计算机(例如PC(个人计算机)服务器、UNIX服务器、中端服务器)、刀片式服务器、大型计算机、服务器群集或任何其他适当的布置和/或组合。服务器120可以包括运行虚拟操作***的一个或多个虚拟机,或者涉及虚拟化的其他计算架构(例如可以被虚拟化以维护服务器的虚拟存储设备的逻辑存储设备的一个或多个灵活池)。在各种实施例中,服务器120可以运行提供下文所描述的功能的一个或多个服务或软件应用。
服务器120中的计算单元可以运行包括上述任何操作***以及任何商业上可用的服务器操作***的一个或多个操作***。服务器120还可以运行各种附加服务器应用程序和/或中间层应用程序中的任何一个,包括HTTP服务器、FTP服务器、CGI服务器、JAVA服务器、数据库服务器等。
在一些实施方式中,服务器120可以包括一个或多个应用程序,以分析和合并从客户端设备101、102、103、104、105和106的用户接收的数据馈送和/或事件更新。服务器120还可以包括一个或多个应用程序,以经由客户端设备101、102、103、104、105和106的一个或多个显示设备来显示数据馈送和/或实时事件。
在一些实施方式中,服务器120可以为分布式***的服务器,或者是结合了区块链的服务器。服务器120也可以是云服务器,或者是带人工智能技术的智能云计算服务器或智能云主机。云服务器是云计算服务体系中的一项主机产品,以解决传统物理主机与虚拟专用服务器(VPS,Virtual Private Server)服务中存在的管理难度大、业务扩展性弱的缺陷。
***100还可以包括一个或多个数据库130。在某些实施例中,这些数据库可以用于存储数据和其他信息。例如,数据库130中的一个或多个可用于存储诸如音频文件和视频文件的信息。数据库130可以驻留在各种位置。例如,由服务器120使用的数据库可以在服务器120本地,或者可以远离服务器120且可以经由基于网络或专用的连接与服务器120通信。数据库130可以是不同的类型。在某些实施例中,由服务器120使用的数据库例如可以是关系数据库。这些数据库中的一个或多个可以响应于命令而存储、更新和检索到数据库以及来自数据库的数据。
在某些实施例中,数据库130中的一个或多个还可以由应用程序使用来存储应用程序数据。由应用程序使用的数据库可以是不同类型的数据库,例如键值存储库,对象存储库或由文件***支持的常规存储库。
图1的***100可以以各种方式配置和操作,以使得能够应用根据本公开所描述的各种方法和装置。
下面参考图2描述根据本公开的示例性实施例的图像分类模型训练方法200。
在步骤S201处,接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数。
在步骤S202处,基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率。
在步骤S203处,向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
根据本公开的实施例所述的方法,能够在减少数据传输量和保证数据隐私的情况下实现针对性的模型调节,从而节省数据体量、减少传输所需的资源并且增加计算效率。
分布式模型训练中,不同设备上样本数据存在差异。在对分布式模型训练中的各个设备训练的模型进行处理时,如果不考虑数据差异,可能会导致各个设备的模型不够准确;但是如果读取各个设备的数据,则可能会有数据传输量大或者隐私泄露等问题。
根据本公开的实施例,通过引入设备样本分布指标,可以在不需要直接接收各个设备的样本数据的情况下,了解各个设备的数据分布差异,从而可以针对不同的设备进行不同的参数调整。这里的参数调整可以包括对模型参数(例如,节点或权重)的针对性调整、非结构化剪枝(例如,权重置零)或者结构化剪枝等等,并且本公开不限于此。
由此,一方面可以减少数据传输量,一方面可以保证设备本地的数据隐私性——同时仍然能够获得传统分布式模型训练的好处(比如,节省训练时间、利用多个设备的计算能力并行计算或不需要在一个设备上存储大量样本数据等)。此外,这样的训练结果还可以准确地考虑设备数据差异。
可以理解的是,这里的“预定”是可以随着样本调整、数据更新、引入新设备等等而随时间而变的,而非不能改变或永远固定的。此外,可以理解的是,基于分别从所述K个设备接收的K个设备图像分布指标意指从所述K个设备接收的K个设备图像分布指标的操作可以发生在模型训练之前或者训练之后。例如,可以在训练开始之前就从各设备接收样本分布指标,或者,可以仅在需要剪枝操作的时候才接收这个指标,或者可以定期地接收,等等,并且本公开不限于此。作为一个示例,分布指标可以形如Pk(y),以表示标签y在设备k上的可能性。
此外,可以理解的是,这里的图像处理可以包括本领域技术人员能够理解的各种图像分类、对象识别或目标检测算法等。因而,可以理解的是,“图像类别”在这里可以意指图像作为整体的类别,也可以意指图像中待检测的对象或目标的类别。因而,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个的表述可以涵盖样本图像包含多个待检测目标因而被标记有多个标签的场景。例如,样本图像中含有一只猫一只狗,因此被标记有“猫”和“狗”的标签,并且在这样的示例中,每个标签还可以可选地与对应的检测框或者位置相对应。可以理解的是,以上仅为示例,本公开不限于此。
根据一些实施例,方法200还可以包括:基于所述K个模型参数集分别确定对应的K个期望设备模型剪枝率,每个期望设备模型剪枝率表示对应的设备训练的模型满足特征损失度条件的最大剪枝率。在这样的实施例中,所述第一设备模型更新数据可以是基于所述K个设备图像分布指标和所述K个期望设备模型剪枝率而确定的,所述第一设备模型更新数据用于对所述第一设备训练的模型进行结构化剪枝。根据这样的实施例,对模型进行更新可以包括进行剪枝处理:减少模型体量,增加处理速度,减少运算量。
作为满足特征损失度条件的最大剪枝率的一个示例,可以采用以下标准:记更新后的参数为W′k,计算海森矩阵H(W′k),并将它的特征值按升序排列,即
Figure BDA0003691096790000091
其中dk表示海森矩阵的秩,m表示特征值的索引。定义基函数
Figure BDA0003691096790000092
其中
Figure BDA0003691096790000093
表示损失函数的梯度并记它的李普希茨常数为
Figure BDA0003691096790000094
在本示例中,可以认为第一个满足
Figure BDA0003691096790000095
Figure BDA0003691096790000096
的mk能够避免准确率的减少,并由此计算出期望的剪枝率
Figure BDA0003691096790000097
可以理解的是,以上仅为示例,并且本公开不限于此。作为另一个示例,剪枝率可以是根据不同设备的准确度、模型体量要求、存储空间阈值要求而设定的,可以被预设为相同或不同的百分比,例如30%、50%、60%……等等,或者也可以根据模型数据量、存储空间等来调整,并且本公开不限于此。
如本领域技术人员能够理解的,剪枝率可以指要剪枝的特征值与所有特征值的比率。结构化剪枝是指“剪去”模型中的一些通道,从而减少模型的计算量,增加计算速度。
根据一些实施例,确定第一设备模型更新数据可以包括:基于所述K个期望设备模型剪枝率的加权平均确定全局期望剪枝率,其中每个期望设备模型剪枝率的权重可以基于对应的设备图像分布指标;并且基于所述全局期望剪枝率确定所述第一设备模型更新数据。
根据这样的实施例,能够基于设备样本分布指标获得加权平均的剪枝率,以保证不损失特征的最优化剪枝。
根据一些实施例,方法200还可以包括:通过对所述K个模型参数集进行聚合,获得聚合模型参数集;以及基于共享图像集和所述聚合模型确定期望聚合模型剪枝率。在这样的实施例中,确定第一设备模型更新数据可以包括:基于所述K个期望设备模型剪枝率和所述期望聚合模型剪枝率的加权平均确定全局期望剪枝率。加权平均操作的权重可以基于所述K个设备图像分布指标和共享图像分布指标。所述共享图像分布指标可以包括所述预定图像类别集合中的每个类别在所述共享图像集中的出现频率。方法200还可以包括:基于所述全局期望剪枝率确定所述第一设备模型更新数据。
在分布式模型训练中,往往需要在服务器端对模型进行聚合。在一些数据共享场景下,服务器端可以具有数据。通过基于服务器端的数据,可以更好地反映不同的数据差异。服务器和K个设备进行加权平均以更加全面地考虑全局中数据和特征的影响。可以理解的是,模型聚合可以采用各种本领域已知的模型聚合算法,例如但不限于FedAvg聚合方式。
根据一些实施例,所述第一设备模型更新数据可以包括针对模型的每一层的层剪枝率。在这样的实施例中,所述层剪枝率可以是通过以下操作来确定的:基于所述聚合模型参数集和所述全局期望剪枝率确定剪枝参数阈值;以及通过将所述聚合模型参数集中的每一层的权重参数与所述剪枝参数阈值进行比较,确定每一层的层剪枝率。
作为一个示例,可以将所有的权重参数按照绝对值从小到大排序,并且将在所有权重中按照剪枝率的比率选取的最小的参数的绝对值大小作为阈值。之后,对于每一层卷积层,通过将所有绝对值比阈值小的参数的数量除以该层总的参数数量来得到该层的剪枝率。由此,能够实现更加细化的、对每层不同的剪枝效果。
根据一些实施例,所述第一设备可以被配置成,在接收到所述第一设备模型更新数据后,对所述第一设备训练的模型的每一层:确定该层所输出的特征图的多个秩中的每一个秩的值;从所述多个秩中选择值最小的第一数量的秩作为待被剪枝的秩,所述第一数量可以是基于该层的层剪枝率而确定的;以及去除该层中的与所述待被剪枝的秩对应的过滤器。
这样的操作是基于以下思想:对于一个给定的模型,特征图几乎是不变的。因此,可以假定在服务器上计算得到的秩和边缘设备计算得到的秩相似,因此根据服务器计算得到的特征图来对模型进行剪枝。计算并升序排序Rl中特征图的秩得到Rl。保留Rl中最后
Figure BDA0003691096790000111
秩对应的过滤器,从而,能够获得最高的剪枝率。
根据一些实施例,方法200还可以包括,在接收K个由设备训练的模型的K个模型参数集之前,向所述K个设备中的每个设备发送所述待训练模型。
设备训练的模型可以是基于服务器下发的模型,由此,能够对经过联邦学习和分布式训练的模型进行分别的剪枝操作。作为一个进一步的示例,通过步骤的迭代执行,服务器下发的模型又进而可以是基于接收多个设备训练的模型进行聚合而获得的聚合模型等等。由此,可以实现多轮训练和灵活的剪枝操作。
根据一些实施例,为所述K个设备中的至少第一设备确定第一设备模型更新数据可以包括为所述K个设备中的每个设备确定对应的设备模型更新数据,并且所述方法还可以包括向所述K个设备中的每个设备发送对应的设备模型更新数据以获得K个经更新的设备模型。
可以对所有不同的设备进行不同的更新(剪枝或其他模型调整操作)。在剪枝操作之后可以进一步对模型进行训练,从而获得更准确并且也更加适用于各设备数据的模型。
根据一些实施例,方法200还可以包括,在获得K个经更新的设备模型之后:使得所述K个设备分别基于对应的样本图像集对所述K个经更新的设备模型中的对应模型进行训练。
在剪枝之后还可以进行进一步的训练。也即,剪枝或其他模型修改操作可以发生在训练过程中的任何位置。
联邦学习是一种分布式机器学习技术。针对联邦学习中数据的训练效率问题,提出了一种新的在具有全局共享数据条件下的解决方案。具体地,根据本公开的一个或多个实施例,能够利用边缘设备数据和服务器数据,对模型进行结构化剪枝,以此减少模型的传输量和计算量,加快联邦学习的训练效率,同时保证了模型的精度。
下面结合图3描述根据本公开的一个或多个实施例的数据流。在实际操作过程中,训练过程可以包括多轮,并且每一轮的训练可以包含步骤301-304。以下5个步骤。
在步骤301处,服务器310从所有的设备320-1、320-2、……320-N中随机选择一定比例的设备
Figure BDA0003691096790000121
来训练全局模型,并将全局模型下发到被选择的设备上,其中t表示第t轮,t可以是正整数。
在步骤302处,每个设备320-1、320-2、……320-N(或其中被选择的一些设备)使用本地数据来更新模型,并且在模型更新好后,设备将其上传到服务器。
在步骤303处,在服务器上对上传的模型进行聚合。作为一个示例,可以使用FedAvg聚合方式,但是本公开不限于此。
在步骤304处,确定是否要进行剪枝步骤。例如,可以在特定的轮数,使用设备数据和服务器数据及相应的统计信息来对模型进行剪枝。如果确定本轮应当进行剪枝,则可以生成经剪枝的模型330。否则,可以生成全局模型340。
通过使用设备数据和服务器数据来对模型进行剪枝,能够减少模型的通信量和计算量。
继续考虑图3的联邦学习***作为示例,其中包含1个服务器和N个边缘设备。可以理解的是,虽然示出了多个边缘设备,但是***可以包含更多的或者更少的(甚至,一个)设备,并且本公开不限于此。每个边缘设备会使用本地的数据集来对服务器下发的模型进行训练,服务器则会聚合边缘设备上传的训练好的模型参数,并将聚合后的模型参数进行下发,以让边缘设备继续下一轮的训练。假设设备k具有数据集
Figure BDA0003691096790000131
Figure BDA0003691096790000132
表示数据集
Figure BDA0003691096790000133
的数据量。这里的xk,j表示第k个设备的第j个输入数据,yk,j表示xk,j的标签。整体数据集可以表示为
Figure BDA0003691096790000134
总的样本为
Figure BDA0003691096790000135
Figure BDA0003691096790000136
作为一个示例,训练的目标可以是找到模型参数w来最小化整体数据集上的损失函数。例如,最优目标可以如下表示:
Figure BDA0003691096790000137
其中,
Figure BDA0003691096790000138
是本地损失函数,损失函数f(w,xk,j,yk,j)衡量模型参数w在数据对{xk,j,yk,j}上的误差。
在此,可以使用JS(Jensen–Shannon)散度来表示设备和服务器上的数据非独立同分布的程度,如下所示:
Figure BDA0003691096790000139
其中,
Figure BDA00036910967900001310
中的Pk(y)表示标签y在设备k或服务器上(k=0)的可能性。
Figure BDA00036910967900001311
是KL(Kullback-Leibler)散度,定义如下:
Figure BDA00036910967900001312
较高的非独立同分布程度表明一个设备或服务器和全局数据分布差异更大。数据的统计信息比如Pk在训练过程中可以在设备和服务器间共享,相比于在设备和服务器之间传输原始数据,这会产生非常少的隐私泄露问题。
下面结合表1给出了根据本公开的实施例的算法的一个示例。可以理解的是,以下算法仅为示例,并且本公开不限于此。结合表1,能够利用设备数据和服务器数据,根据模型每一层特征和重要性,在服务器进行独特的剪枝操作,以此来提高联邦学习的训练效率。
表1示例剪枝算法
Figure BDA0003691096790000141
其中,输入可以包括:
L:要剪枝的卷积层的列表
D:所有设备和服务器的集合
w:初始模型
w*:第t轮时的当前模型
W=[υ1,υ2,…,υm]:模型中的参数列表,其中m表示参数的数量并且其中,输出可以包括w′,表示在第t轮剪枝的模型。
如表1中的第2-4行所示,对于每个设备和服务器,使用服务器数据和设备数据计算出期望的剪枝率。作为一个示例,在设备k或者服务器上,给定一个有着初始参数Wk的神经网络。在各个设备上的操作可以是并行的。在进行T轮的训练后,记这个更新后的参数为W′k,相应的差为Δk=Wk-W′k。之后,可以计算海森矩阵的损失函数,即H(W′k),并将它的特征值按升序排列,即
Figure BDA0003691096790000142
其中dk表示海森矩阵的秩,m表示一个特征值的索引。
定义一个基函数
Figure BDA0003691096790000143
其中
Figure BDA0003691096790000144
表示损失函数的梯度并记它的李普希茨常数为
Figure BDA0003691096790000145
计算发现,第一个满足
Figure BDA0003691096790000151
的mk能够避免准确率的减少。由此,可以计算出期望的剪枝率
Figure BDA0003691096790000152
即要剪枝的特征值与所有特征值的比率。
由于数据的非独立同分布,导致每个设备期望剪枝率不相同,可以使用公式(4)在服务器上计算一个聚合的期望剪枝率(例如,见表1的第5行)。
Figure BDA0003691096790000153
其中∈是一个很小的常数用于避免被除数为0。之后,可以参考表1中第6-7行,计算一个阈值
Figure BDA0003691096790000154
来单独为每一层计算剪枝率。
先将所有的权重参数按照绝对值从小到大排序,全局的阈值就是第
Figure BDA0003691096790000158
个参数的绝对值大小。
之后,参考表1第9-11行,可以对于每一层卷积层,通过将所有绝对值比阈值小的参数的数量除以该层总的参数数量来得到该层的剪枝率。
随后,参考表1的第12-15行,可以基于特征图的秩来对模型进行剪枝。记第l层特征图的秩为
Figure BDA0003691096790000155
其中dl表示第l层过滤器的数量。由于对于一个给定的模型,特征图几乎是不变的,因此,假定在服务器上计算得到的秩和边缘设备计算得到的秩相似,并且因此根据服务器计算得到的特征图来对模型进行剪枝。计算并升序排序Rl中特征图的秩得到Rl。保留Rl中最后
Figure BDA0003691096790000156
秩对应的过滤器以获得最高的剪枝率
Figure BDA0003691096790000157
(第14行)。最终,可以将原始模型的层替换为保留下来的过滤器。
下面参考图4描述根据本公开的示例性实施例的图像分类模型训练方法400。
在步骤401处,向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个。
在步骤402处,从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率。
在步骤403处,基于所述第一设备模型更新数据对所述第一模型进行更新。
根据本公开的实施例所述的方法,能够在减少数据传输量和保证数据隐私的情况下实现针对性的模型调节,从而节省数据体量、减少传输所需的资源并且增加计算效率。
可以理解的是,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器意指本公开不限于确定和发送第一图像分布指标的时刻。例如,可以并非在每次训练前都确定和发送指标,指标可能早已由本地计算并发送到服务器以供存储了。或者,可以在本次训练之后并且在剪枝之前才触发指标的确定和发送,或者与某次的训练数据一起发送,等等,并且本公开不限于此。
根据一些实施例,基于所述第一设备模型更新数据对所述第一模型进行更新可以包括对所述第一模型进行结构化剪枝,并且其中,所述第一设备模型更新数据可以至少基于所述图像分布指标和期望设备模型剪枝率而被确定,所述期望设备模型剪枝率可以基于所述第一模型参数集并且表示所述第一模型满足特征损失度条件的最大剪枝率。
根据这样的实施例,对模型进行更新可以包括进行剪枝处理:减少模型体量,增加处理速度,减少运算量。特征损失度条件可以包括基于海森矩阵H(W′k)的特征值计算的特征损失度,或者是根据不同设备的准确度、模型体量要求、存储空间阈值要求而设定的,可以被预设为相同或不同的百分比,例如30%、50%、60%……等等,或者也可以根据模型数据量、存储空间等来调整,并且本公开不限于此。
根据一些实施例,所述第一设备模型更新数据可以包括针对模型的每一层的层剪枝率,并且其中,对所述第一模型进行结构化剪枝可以包括,对所述第一模型的每一层:确定该层所输出的特征图的秩中的值最小的第一数量的秩作为待被剪枝的秩,所述第一数量是基于该层的层剪枝率而确定的;以及去除该层中的与所述待被剪枝的秩对应的过滤器。
作为一个示例,可以将所有的权重参数按照绝对值从小到大排序,并且将在所有权重中按照剪枝率的比率选取的最小的参数的绝对值大小作为阈值。之后,对于每一层卷积层,通过将所有绝对值比阈值小的参数的数量除以该层总的参数数量来得到该层的剪枝率。由此,能够实现更加细化的、对每层不同的剪枝效果。
根据一些实施例,方法400还可以包括,在基于样本图像集对待训练模型进行训练以获得第一模型之前:从所述服务器接收所述待训练模型。
根据一些实施例,方法400还可以包括,在对所述第一模型进行更新之后,基于所述样本图像集对经更新的第一模型进行训练。
现在参考图5描述根据本公开的实施例的图像分类模型训练装置500。图像分类模型训练装置500可以包括模型参数接收单元501、更新数据确定单元502和更新数据发送单元503。模型参数接收单元501可以用于接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数;
更新数据确定单元502可以用于基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率;以及
更新数据发送单元503可以用于向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
根据本公开的实施例所述的装置,能够在减少数据传输量和保证数据隐私的情况下实现针对性的模型调节,从而节省数据体量、减少传输所需的资源并且增加计算效率。
现在参考图6描述根据本公开的实施例的图像分类模型训练装置600。图像分类模型训练装置600可以包括发送单元601、接收单元602和更新单元603。发送单元601可以用于向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个。接收单元602可以用于从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率。更新单元603可以用于基于所述第一设备模型更新数据对所述第一模型进行更新。
根据本公开的实施例所述的装置,能够在减少数据传输量和保证数据隐私的情况下实现针对性的模型调节,从而节省数据体量、减少传输所需的资源并且增加计算效率。
本公开的技术方案中,所涉及的用户个人信息的收集、获取,存储、使用、加工、传输、提供和公开应用等处理,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
根据本公开的另一方面,还提供了一种边缘计算设备,可选的,边缘计算设备除了包括电子设备,还可以包括通信部件等,电子设备可以和通信部件一体集成,也可以分体设置。电子设备可以获取路侧感知设备(如路侧相机)的数据,例如图片和视频等,从而进行图像视频处理和数据计算,再经由通信部件向云控平台传送处理和计算结果。
可选的,边缘计算设备也可以为路侧计算单元(Road Side Computing Unit,RSCU)。可选的,电子设备自身也可以具备感知数据获取功能和通信功能,例如是AI相机,电子设备可以直接基于获取的感知数据进行图像视频处理和数据计算,再向云控平台传送处理和计算结果。
可选的,云控平台在云端执行处理,进行图像视频处理和数据计算,云控平台也可以称为车路协同管理平台、V2X平台、云计算平台、中心***、云端服务器等。
参考图7,现将描述可以作为本公开的服务器或客户端的电子设备700的结构框图,其是可以应用于本公开的各方面的硬件设备的示例。电子设备旨在表示各种形式的数字电子的计算机设备,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图7所示,电子设备700包括计算单元701,其可以根据存储在只读存储器(ROM)702中的计算机程序或者从存储单元708加载到随机访问存储器(RAM)703中的计算机程序,来执行各种适当的动作和处理。在RAM703中,还可存储电子设备700操作所需的各种程序和数据。计算单元701、ROM 702以及RAM 703通过总线704彼此相连。输入/输出(I/O)接口705也连接至总线704。
电子设备700中的多个部件连接至I/O接口705,包括:输入单元706、输出单元707、存储单元708以及通信单元709。输入单元706可以是能向电子设备700输入信息的任何类型的设备,输入单元706可以接收输入的数字或字符信息,以及产生与电子设备的用户设置和/或功能控制有关的键信号输入,并且可以包括但不限于鼠标、键盘、触摸屏、轨迹板、轨迹球、操作杆、麦克风和/或遥控器。输出单元707可以是能呈现信息的任何类型的设备,并且可以包括但不限于显示器、扬声器、视频/音频输出终端、振动器和/或打印机。存储单元708可以包括但不限于磁盘、光盘。通信单元709允许电子设备700通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据,并且可以包括但不限于调制解调器、网卡、红外通信设备、无线通信收发机和/或芯片组,例如蓝牙TM设备、802.11设备、WiFi设备、WiMax设备、蜂窝通信设备和/或类似物。
计算单元701可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元701的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元701执行上文所描述的各个方法和处理,例如方法200和/或400及其变型例等。例如,在一些实施例中,方法200和/或400及其变型例等可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元708。在一些实施例中,计算机程序的部分或者全部可以经由ROM 702和/或通信单元709而被载入和/或安装到电子设备700上。当计算机程序加载到RAM 703并由计算单元701执行时,可以执行上文描述的方法200和/或400及其变型例等的一个或多个步骤。备选地,在其他实施例中,计算单元701可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行方法200和/或400及其变型例等。
本文中以上描述的***和技术的各种实施方式可以在数字电子电路***、集成电路***、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上***的***(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程***上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储***、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储***、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行***、装置或设备使用或与指令执行***、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体***、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的***和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的***和技术实施在包括后台部件的计算***(例如,作为数据服务器)、或者包括中间件部件的计算***(例如,应用服务器)、或者包括前端部件的计算***(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的***和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算***中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将***的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机***可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式***的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行、也可以顺序地或以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
虽然已经参照附图描述了本公开的实施例或示例,但应理解,上述的方法、***和设备仅仅是示例性的实施例或示例,本发明的范围并不由这些实施例或示例限制,而是仅由授权后的权利要求书及其等同范围来限定。实施例或示例中的各种要素可以被省略或者可由其等同要素替代。此外,可以通过不同于本公开中描述的次序来执行各步骤。进一步地,可以以各种方式组合实施例或示例中的各种要素。重要的是随着技术的演进,在此描述的很多要素可以由本公开之后出现的等同要素进行替换。

Claims (19)

1.一种图像分类模型训练方法,包括:
接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数;
基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率;以及
向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
2.根据权利要求1所述的方法,还包括:基于所述K个模型参数集分别确定对应的K个期望设备模型剪枝率,每个期望设备模型剪枝率表示对应的设备训练的模型满足特征损失度条件的最大剪枝率;并且
其中,所述第一设备模型更新数据是基于所述K个设备图像分布指标和所述K个期望设备模型剪枝率而确定的,所述第一设备模型更新数据用于对所述第一设备训练的模型进行结构化剪枝。
3.根据权利要求2所述的方法,其中,确定第一设备模型更新数据包括:
基于所述K个期望设备模型剪枝率的加权平均确定全局期望剪枝率,其中每个期望设备模型剪枝率的权重基于对应的设备图像分布指标;并且
基于所述全局期望剪枝率确定所述第一设备模型更新数据。
4.根据权利要求2所述的方法,还包括:
通过对所述K个模型参数集进行聚合,获得聚合模型参数集;以及
基于共享图像集和所述聚合模型确定期望聚合模型剪枝率;并且
其中,确定第一设备模型更新数据包括:
基于所述K个期望设备模型剪枝率和所述期望聚合模型剪枝率的加权平均确定全局期望剪枝率,其中,加权平均操作的权重基于所述K个设备图像分布指标和共享图像分布指标,其中,所述共享图像分布指标包括所述预定图像类别集合中的每个类别在所述共享图像集中的出现频率;以及
基于所述全局期望剪枝率确定所述第一设备模型更新数据。
5.根据权利要求4所述的方法,其中,所述第一设备模型更新数据包括针对模型的每一层的层剪枝率,并且其中,所述层剪枝率是通过以下操作来确定的:
基于所述聚合模型参数集和所述全局期望剪枝率确定剪枝参数阈值;以及
通过将所述聚合模型参数集中的每一层的权重参数与所述剪枝参数阈值进行比较,确定每一层的层剪枝率。
6.根据权利要求5所述的方法,其中,所述第一设备被配置成,在接收到所述第一设备模型更新数据后,对所述第一设备训练的模型的每一层:
确定该层所输出的特征图的多个秩中的每一个秩的值;
从所述多个秩中选择值最小的第一数量的秩作为待被剪枝的秩,所述第一数量是基于该层的层剪枝率而确定的;以及
去除该层中的与所述待被剪枝的秩对应的过滤器。
7.根据权利要求1-6中任一项所述的方法,还包括,在接收K个由设备训练的模型的K个模型参数集之前,向所述K个设备中的每个设备发送所述待训练模型。
8.根据权利要求1-7中任一项所述的方法,其中,为所述K个设备中的至少第一设备确定第一设备模型更新数据包括为所述K个设备中的每个设备确定对应的设备模型更新数据,并且所述方法还包括向所述K个设备中的每个设备发送对应的设备模型更新数据以获得K个经更新的设备模型。
9.根据权利要求8所述的方法,还包括,在获得K个经更新的设备模型之后:
使得所述K个设备分别基于对应的样本图像集对所述K个经更新的设备模型中的对应模型进行训练。
10.一种图像分类模型训练方法,包括:
向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个;
从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率;以及
基于所述第一设备模型更新数据对所述第一模型进行更新。
11.根据权利要求10所述的方法,其中,基于所述第一设备模型更新数据对所述第一模型进行更新包括对所述第一模型进行结构化剪枝,并且,
其中,所述第一设备模型更新数据至少基于所述图像分布指标和期望设备模型剪枝率而被确定,所述期望设备模型剪枝率基于所述第一模型参数集并且表示所述第一模型满足特征损失度条件的最大剪枝率。
12.根据权利要求11所述的方法,其中,所述第一设备模型更新数据包括针对模型的每一层的层剪枝率,并且其中,对所述第一模型进行结构化剪枝包括,对所述第一模型的每一层:
确定该层所输出的特征图的秩中的值最小的第一数量的秩作为待被剪枝的秩,所述第一数量是基于该层的层剪枝率而确定的;以及
去除该层中的与所述待被剪枝的秩对应的过滤器。
13.根据权利要求10-12中任一项所述的方法,还包括,在基于样本图像集对待训练模型进行训练以获得第一模型之前:从所述服务器接收所述待训练模型。
14.根据权利要求10-13中任一项所述的方法,还包括,在对所述第一模型进行更新之后,基于所述样本图像集对经更新的第一模型进行训练。
15.一种图像分类模型训练装置,包括:
模型参数接收单元,用于接收K个由设备训练的模型的K个模型参数集,所述K个由设备训练的模型是分别通过位于K个设备中的对应设备上的样本图像集对相同的待训练模型进行训练而获得的,每个样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个,K为正整数;
更新数据确定单元,用于基于分别从所述K个设备接收的K个设备图像分布指标和所述K个模型参数集,为所述K个设备中的至少第一设备中的第一设备训练的模型确定第一设备模型更新数据,其中,所述K个设备图像分布指标中的每一个设备图像分布指标包括所述预定图像类别集合的每个类别在对应的设备上的样本图像集中的出现频率;以及
更新数据发送单元,用于向所述第一设备发送所述第一设备模型更新数据,以使得所述第一设备基于所述第一设备模型更新数据对所述第一设备训练的模型进行更新。
16.一种图像分类模型训练装置,包括:
发送单元,用于向服务器发送第一模型参数集,所述第一模型参数集用于表征基于样本图像集对待训练模型进行训练以获得的第一模型,所述样本图像集包括至少一个样本图像,每个样本图像被标记有预定图像类别集合中的图像类别中的至少一个;
接收单元,用于从所述服务器接收第一设备模型更新数据,所述第一设备模型更新数据至少基于所述第一模型参数集和第一图像分布指标,所述第一图像分布指标是基于所述样本图像集而确定的,所述第一图像分布指标在从所述服务器接收第一设备模型更新数据之前被发送到所述服务器,并且所述图像分布指标包括所述预定图像类别集合的每个类别在所述样本图像集中的出现频率;以及
更新单元,用于基于所述第一设备模型更新数据对所述第一模型进行更新。
17.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-9或10-14中任一项所述的方法。
18.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-9或10-14中任一项所述的方法。
19.一种计算机程序产品,包括计算机程序,其中,所述计算机程序在被处理器执行时实现权利要求1-9或10-14中任一项所述的方法。
CN202210664512.2A 2022-06-13 2022-06-13 图像分类模型训练方法、装置、电子设备、介质 Active CN115100461B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210664512.2A CN115100461B (zh) 2022-06-13 2022-06-13 图像分类模型训练方法、装置、电子设备、介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210664512.2A CN115100461B (zh) 2022-06-13 2022-06-13 图像分类模型训练方法、装置、电子设备、介质

Publications (2)

Publication Number Publication Date
CN115100461A true CN115100461A (zh) 2022-09-23
CN115100461B CN115100461B (zh) 2023-08-22

Family

ID=83290701

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210664512.2A Active CN115100461B (zh) 2022-06-13 2022-06-13 图像分类模型训练方法、装置、电子设备、介质

Country Status (1)

Country Link
CN (1) CN115100461B (zh)

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210357814A1 (en) * 2020-12-18 2021-11-18 Beijing Baidu Netcom Science And Technology Co., Ltd. Method for distributed training model, relevant apparatus, and computer readable storage medium
CN113971455A (zh) * 2020-07-24 2022-01-25 腾讯科技(深圳)有限公司 一种分布式模型训练方法、装置、存储介质及计算机设备
WO2022037337A1 (zh) * 2020-08-19 2022-02-24 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
WO2022056422A1 (en) * 2020-09-14 2022-03-17 The Regents Of The University Of California Ensemble learning of diffractive neural networks
CN114548298A (zh) * 2022-02-25 2022-05-27 阿波罗智联(北京)科技有限公司 模型训练、交通信息处理方法、装置、设备和存储介质
CN114581932A (zh) * 2022-01-28 2022-06-03 中国电建集团山东电力建设有限公司 一种图片表格线提取模型构建方法及图片表格提取方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113971455A (zh) * 2020-07-24 2022-01-25 腾讯科技(深圳)有限公司 一种分布式模型训练方法、装置、存储介质及计算机设备
WO2022037337A1 (zh) * 2020-08-19 2022-02-24 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
WO2022056422A1 (en) * 2020-09-14 2022-03-17 The Regents Of The University Of California Ensemble learning of diffractive neural networks
US20210357814A1 (en) * 2020-12-18 2021-11-18 Beijing Baidu Netcom Science And Technology Co., Ltd. Method for distributed training model, relevant apparatus, and computer readable storage medium
CN114581932A (zh) * 2022-01-28 2022-06-03 中国电建集团山东电力建设有限公司 一种图片表格线提取模型构建方法及图片表格提取方法
CN114548298A (zh) * 2022-02-25 2022-05-27 阿波罗智联(北京)科技有限公司 模型训练、交通信息处理方法、装置、设备和存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
YAOHUI CAI,WEIZHE HUA: "Structured pruning is all you need for pruning CNNs at Initialization", 《ARXIV:2203.02549V1》 *
廖绍雯;贾聪;: "基于Map-Reduce框架的C4.5分布式改进算法", 自动化与仪器仪表, no. 08 *

Also Published As

Publication number Publication date
CN115100461B (zh) 2023-08-22

Similar Documents

Publication Publication Date Title
CN112579909A (zh) 对象推荐方法及装置、计算机设备和介质
CN112857268B (zh) 对象面积测量方法、装置、电子设备和存储介质
CN114187459A (zh) 目标检测模型的训练方法、装置、电子设备以及存储介质
CN115511779B (zh) 图像检测方法、装置、电子设备和存储介质
WO2023245938A1 (zh) 对象推荐方法和装置
CN114445667A (zh) 图像检测方法和用于训练图像检测模型的方法
CN114723949A (zh) 三维场景分割方法和用于训练分割模型的方法
CN116883181B (zh) 基于用户画像的金融服务推送方法、存储介质及服务器
CN116341680A (zh) 人工智能模型适配方法、装置、电子设备以及存储介质
JP2024507602A (ja) データ処理方法及び予測モデルをトレーニングするための方法
CN113868453B (zh) 对象推荐方法和装置
CN115100461B (zh) 图像分类模型训练方法、装置、电子设备、介质
CN114494797A (zh) 用于训练图像检测模型的方法和装置
CN114120416A (zh) 模型训练方法、装置、电子设备及介质
CN115033782B (zh) 推荐对象的方法、机器学习模型的训练方法、装置和设备
CN116881485B (zh) 生成图像检索索引的方法及装置、电子设备和介质
CN115512131B (zh) 图像检测方法和图像检测模型的训练方法
CN113420227B (zh) 点击率预估模型的训练方法、预估点击率的方法、装置
CN115809364B (zh) 对象推荐方法和模型训练方法
CN115546510A (zh) 图像检测方法和图像检测模型训练的方法
CN118364179A (zh) 资源的推荐方法、资源推荐模型的训练方法、装置、电子设备和介质
CN115829653A (zh) 广告文本的相关度确定方法及装置、设备和介质
CN114021650A (zh) 数据处理方法、装置、电子设备和介质
CN114219079A (zh) 特征选择方法及装置、模型训练方法及装置、设备和介质
CN115292608A (zh) 内容资源推荐方法、神经网络的训练方法、装置和设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant