CN116740505A - 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备 - Google Patents

图像分类模型的训练、图像分类方法、装置、机器可读介质及设备 Download PDF

Info

Publication number
CN116740505A
CN116740505A CN202310818502.4A CN202310818502A CN116740505A CN 116740505 A CN116740505 A CN 116740505A CN 202310818502 A CN202310818502 A CN 202310818502A CN 116740505 A CN116740505 A CN 116740505A
Authority
CN
China
Prior art keywords
image
training
image classification
classification model
text
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310818502.4A
Other languages
English (en)
Inventor
许超
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Yuncong Technology Group Co Ltd
Original Assignee
Yuncong Technology Group Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Yuncong Technology Group Co Ltd filed Critical Yuncong Technology Group Co Ltd
Priority to CN202310818502.4A priority Critical patent/CN116740505A/zh
Publication of CN116740505A publication Critical patent/CN116740505A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Multimedia (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种图像分类模型的训练方法,包括:获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。在本发明中,图像分类模型是通过在训练集中的中文描述性文本样本中加入了可学习语义向量,通过可学习语义向量的学习,提升了图像分类模型的性能,从而使得通过利用该图像分类模型对图像进行分类,分类更加准确。

Description

图像分类模型的训练、图像分类方法、装置、机器可读介质及 设备
技术领域
本发明涉及图像处理领域,具体涉及一种图像分类模型的训练方法、装置、机器可读介质及设备。
背景技术
近年来,随着强大的计算设备(例如,GPU和分布式平台)、大型数据集(例如,ImageNet数据集等)、先进模型和算法(例如,卷积神经网络CNN和循环神经网络RNN)的出现,AI缩短了与人类的差距,并在很多领域有击败人类的例子。比如,AlphaGo在围棋领域击败了人类选手。上述的成功在很大程度上依赖大规模数据的学习。而收集大量样本需要耗费大量的时间和金钱成本,甚至由于道德、隐私或安全问题,很难获取大量样本。因此,少样本学习(Few-shot Learning)被提出就是为了解决少量的有监督信号的样本学习的问题。在进行少样本学习时,使用固定的、静态的提示学习模板来进行学习,而这样方式在针对不同的数据集下效果不佳。
发明内容
鉴于以上所述现有技术的缺点,本发明的目的在于提供一种图像分类模型的训练方法、装置、机器可读介质及设备,用于解决现有技术存在的问题。
为实现上述目的及其他相关目的,本发明提供一种图像分类模型的训练方法,所述训练方法包括:
获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;
基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。
于本发明一实施例中,所述利用所述训练集对初始模型进行训练,包括:
通过所述初始模型中的特征提取层对所述训练样本进行特征提取,得到第一文本特征和图像特征;
通过所述初始模型中的上下文感知层利用多头注意力机制使得图像特征为第一文本特征添加注意力,得到第二文本特征;
通过所述初始模型中的相似度度量层计算第二文本特征与所述图像特征之间的相似度;
基于所述相似度构建损失函数,并根据所述损失函数对所述初始模型进行迭代训练,得到图像分类模型。
于本发明一实施例中,通过特征提取层对所述训练样本进行特征提取,包括:
通过特征提取层中的视觉编码器对所述训练样本中的图像数据进行特征提取,得到图像特征;
通过特征提取层是的文本编码器对所述训练样本中的中文描述性数据进行特征提取,得到第一文本特征;
其中,所述文本编码器和所述视觉编码器是以图像样本和对应所述图像样本的中文描述性文本样本构成的训练集对图像特征和第一文本特征进行对比学习训练得到。
于本发明一实施例中,利用多头注意力机制使得图像特征为第一文本特征添加注意力,得到第二文本特征,包括:
对所述图像特征进行全局池化处理,得到全局特征;
将所述图像特征和所述全局特征进行特征融合,得到第一融合特征;
将所述第一融合特征输入到多头注意力网络中,得到第二融合特征;
将所述图像特征和所述第二融合特征进行特征融合,得到第二文本特征。
于本发明一实施例中,所述图像类别表示位于所述可学习语义向量的开始位置、中间位置或结尾位置。
于本发明一实施例中,所述损失函数包括交叉熵损失函数和散度损失函数,其中,交叉熵损失函数用于约束第一预测分类值和真实分类值,所述散度损失函数用于约束第一预测分类值和第二预测分类值,第二预测分类值为零样本预测分类值。
于本发明一实施例中,在训练所述图像分类模型的过程中,固定所述视觉编码器与所述文本编码器的参数,对所述可学习语义向量进行更新。
为实现上述目的及其他相关目的,本发明提供一种图像分类模型的训练装置,所述训练装置包括:
数据获取模块,用于获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;
训练模块,用于基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。
为实现上述目的及其他相关目的,本发明提供一种图像分类方法,所述分类方法包括:
获取待分类图像;
将所述待分类图像输入到所述的图像分类模型中,以所述图像分类模型的输出作为待分类图像的类别。
为实现上述目的及其他相关目的,本发明提供一种图像分类装置,所述分类装置包括:
图像获取模块,用于获取待分类图像;
图像分类模块,用于将所述待分类图像输入到所述的图像分类模型中,以所述图像分类模型的输出作为待分类图像的类别。
为实现上述目的及其他相关目的,本发明还提供一种电子设备,包括:
一个或多个处理器;和
其上存储有指令的一个或多个机器可读介质,当所述一个或多个处理器执行时,使得所述设备执行前述的一个或多个所述的图像分类模型的训练方法或图像分类方法。
为实现上述目的及其他相关目的,本发明还提供一个或多个机器可读介质,其上存储有指令,当由一个或多个处理器执行时,使得设备执行前述的一个或多个所述的图像分类模型的训练方法或所述的图像分类方法。
如上所述,本发明提供的一种图像分类模型的训练方法、装置、机器可读介质及设备、装置、机器可读介质及设备,具有以下有益效果:
本发明的一种图像分类模型的训练方法,包括:获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。在本发明中,图像分类模型是通过在训练集中的中文描述性文本样本中加入了可学习语义向量,通过可学习语义向量的学习,提升了图像分类模型的性能,从而使得通过利用该图像分类模型对图像进行分类,分类更加准确。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本申请。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单的介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明的一示例性实施例示出的图像分类模型的训练方法的实施环境示意图;
图2为本发明的一示例性实施例示出的图像分类模型的训练方法的流程图;
图3为本发明一示例性实施例示出的利用训练集对初始模型进行训练的流程图;
图4为本发明的一示例性实施例示出的利用多头注意力机制使得图像特征为第一文本特征添加注意力的流程图;
图5为本发明的一示例性实施例示出的图像分类模型的训练装置的框图
图6为本发明的一示例性实施例示出的一种图像分类方法流程图;
图7为本发明的一示例性实施例示出的一种图像分类装置的框图;
图8为本发明的一实施例中终端设备的硬件结构示意图;
图9为本发明的一实施例中终端设备的硬件结构示意图。
具体实施方式
以下将参照附图和优选实施例来说明本发明的实施方式,本领域技术人员可由本说明书中所揭露的内容轻易地了解本发明的其他优点与功效。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。应当理解,优选实施例仅为了说明本发明,而不是为了限制本发明的保护范围。
需要说明的是,以下实施例中所提供的图示仅以示意方式说明本发明的基本构想,遂图式中仅显示与本发明中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
在下文描述中,探讨了大量细节,以提供对本发明实施例的更透彻的解释,然而,对本领域技术人员来说,可以在没有这些具体细节的情况下实施本发明的实施例是显而易见的,在其他实施例中,以方框图的形式而不是以细节的形式来示出公知的结构和设备,以避免使本发明的实施例难以理解。
近年来,随着强大的计算设备(例如,GPU和分布式平台)、大型数据集(例如,ImageNet数据集等)、先进模型和算法(例如,卷积神经网络CNN和循环神经网络RNN)的出现,AI缩短了与人类的差距,并在很多领域有击败人类的例子。比如,AlphaGo在围棋领域击败了人类选手。上述的成功在很大程度上依赖大规模数据的学习。相比之下,人类可以利用少量的样本进行快速学习。例如,给小朋友几张没有见过的动物图像,经过学习,他能很迅速地从一堆样本中找到刚刚学习的新动物的相似图像。另一方面,有些时候,收集大量样本需要耗费大量的时间和金钱成本,甚至由于道德、隐私或安全问题,很难获取大量样本。因此,少样本学习(Few-shot Learning)被提出就是为了解决了从少量的有监督信号的样本学习的问题。比如,某些特殊车辆的识别、流行新款衣服的识别等。
一个去构建视觉识别***的常见做法就是:去训练一个视觉模型然后去预测一个由离散标签组成的固定类别,比如10类,100类等。这种学习的方式限制了类别数,成为一个“闭集”的分类问题,这就导致在出现新类别的时候,就需要额外的数据重新训练。因此,构建一个“开集”的识别问题具有很好的扩展性。近年来,视觉-语言预训练模型比如CLIP、ALIGN的出现,给视觉表征学习带来了新的思路。它们的主要思想为:为两种模态(文本和图像)分别构建编码器,然后利用对比学习训练思想去对齐这两种模态。在推理阶段,由于缺乏文本信息,借助prompt learning(提示学习)方法来解决输入文本问题。
在传统的视觉分类中,网络的类别数是预先定义好的,比如10分类、100分类,这个主要是由网络的全连接层所决定的。只有提前预定义好网络类别数,网络的输出和真实的标签才可以利用离散的形式进行优化。随着类似CLIP的视觉-文本预训练模型的兴起,可以利用对比学习训练相关方法来对齐文本和图像空间,这样的话就不需要设定一个固定的类别数了,具有良好的可扩展性;在深度学习的范畴下,往往需要大量的训练数据才可以有效地工作。标注和收集大量的数据样本往往需要昂贵的时间和金钱成本,因此小样本学习的出发点就是利用少量样本的方式进行模型的优化,具有较大的实际应用价值。
最早的CLIP工作只是单纯的利用图像类别信息,然后再使用提示学习(promptlearning)方法将类别变成一个句子,比如“aphoto of{class_name}”。当零样本推理(也就是没有样本参与训练)效果不好的时候,需要利用少量样本(比如1个样本、2个样本等)优化特定场景下的模型。在不同的数据集下,使用生硬的、静态的提示学习模板不一定是最佳的选择。
同时在基于视觉-文本预训练模型的范式下,目前的应用场景主要是英文,由于之前的那些模型训练基本上都是在英文上训练,如果直接应用的话,需要将其先翻译成英文然后才能使用。由于中文存在某些特殊的遣词、语境,很多时候翻译成英文很难表达出原来的意思,而且翻译本身就需要专业的领域知识和时间成本。所以,一种适配中文的视觉-文本预训练模型还是十分有必要的。
针对上述问题,本申请的实施例分别提出一种图像分类模型的训练方法、一种图像分类模型的训练装置、一种图像分类方法、一种图像分类装置、一种电子设备、一种计算机可读存储介质。
首先,对本说明书一个或多个实施例涉及的名词术语进行解释。
CLIP:Contrastive Language-Image Pre-training在大数据和大模型的前提下,利用对比学习训练的思想进行高效的训练,在多个数据集取得了非常不错的效果,无论是zero shot、few shot任务上;
ALIGN:A Large-scale ImaGe and Noisy-text embedding,利用大规模噪声图像-文本数据来扩大视觉和视觉-语言的表示学习。作者避免了对数据预处理和标注的工作量,只需要基于数据频率的简单过滤。在这个数据集上,作者基于对比学习训练损失函数训练一个非常简单的双编码器模型ALIGN;
ResNet:Deep Residual Learning for Image Recognition,该系列网络广泛用于目标分类等领域以及作为计算机视觉任务主干经典神经网络的一部分,典型的网络有resnet50,resnet101等。Resnet网络证明网络能够向更深(包含更多隐藏层)的方向发展。
Transformer:Attention IsAllYouNeed,transfomer的最大特点是抛弃了传统的CNN和RNN,整个网络结构完全是由self-Attention机制组成。由于其出色性能以及对下游任务的友好性,从而广泛应用于NLP领域,例如机器翻译,问答***,文本摘要和语音识别等等方向。
BERT(Bidirectional Encoder Representations from Transformer,双向语义编码)一种对Transformer的优化神经网络模型,通过注意力机制提取、分析自然语言文本。
图1是本申请一种示例性图像分类模型的训练方法实施环境的示意图。请参阅图1,该实施环境中包括终端设备110和服务器120,终端设备110和服务器120之间通过有线或者无线网络进行通信。终端设备可以多个图像,然后基于多个图像构建训练集,其中,每个训练样本包括图像数据和对应图像数据的中文描述性数据;中文描述性数据包括具有图像类别表示的可学习语义向量。在终端设备或/和服务器中可以设置初始模型,基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。在本发明中,图像分类模型是通过在训练集中的中文描述性文本样本中加入了可学习语义向量,通过可学习语义向量的学习,提升了图像分类模型的性能,从而使得通过利用该图像分类模型对图像进行分类,分类更加准确。
应该理解,图1中的终端设备110和服务器120的数目仅仅是示意性的。根据实际需要,可以具有任意数目的终端设备110和服务器120。
其中,终端设备110对应客户端,其可以是任意具有用户输入接口的电子设备,包括但不限于智能手机、平板、笔记本电脑、计算机、车载电脑等等,其中,用户输入接口包括但不限于触摸屏、键盘、物理按键、音频拾取装置等。
其中,服务器120对应服务端,其可以是提供各种服务的服务器,其可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式***,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN(Content DeliveryNetwork,内容分发网络)以及大数据和人工智能平台等基础云计算服务的云服务器,本处不对此进行限制。
终端设备110可以通过3G(第三代的移动信息技术)、4G(***的移动信息技术)、5G(第五代的移动信息技术)等无线网络与服务端120进行通信,本处也不对此进行限制。
请参阅图2,图2是本申请的一示例性实施例示出的一种图像分类模型的训练方法的流程图。该图像分类模型的训练方法可以应用于图1所示的实施环境,并由该实施环境中的服务器120具体执行。应理解的是,该图像分类模型的训练方法也可以适用于其它的示例性实施环境,并由其它实施环境中的设备具体执行,本实施例不对该图像分类模型的训练方法所适用的实施环境进行限制。
请参阅图2,图2为本申请一示例性的一种图像分类模型的训练方法的流程图,该图像分类模型的训练方法至少包括步骤S210至步骤S220,详细介绍如下:
步骤S210,获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;
步骤S220,基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。
在本发明中,图像分类模型是通过在训练集中的中文描述性文本样本中加入了可学习语义向量,通过可学习语义向量的学习,提升了图像分类模型的性能,从而使得通过利用该图像分类模型对图像进行分类,分类更加准确。
在本发明中,假设f是将输入图像x经过图像编码器获取,同时是由文本编码器得到的权重向量。K是图像类别数,wi是来自于“aphoto ofa[CLASS]”的提示模板(提示模板的含义指的是将单词变成一句话的范式(模板)。将单词变成一句话主要有两个方面的考虑:1、由于在中文中一个单词在不同的语境中拥有不同的含义,因此只有处于特定的句子的时候单词的意义才更加准确;2、由于训练的时候是句子级别的训练的,所以的话,在测试的时候也需要将其变成一个句子)。这里的CLASS最终在实现的时候被一些特定的名字所代替,比如“狗”、“猫”、“汽车”等。最后的预测概率可以由如下公式描述:
其中,τ是温度系数,控制着分布的形状,而cso(.,.)代表着余弦相似度。
为了替换生硬、静态的提示模板,这里引入了可学习语义向量,它由如下连续向量构成:t=[V]1[V]2...[V]M[CLASS]
其中,[V]m(m∈{1,2,3...,M})代表着和单词词嵌同样的维度,M是超参数控制着语义向量的维度。通过引入文本编码器g(.)来处理上述t的输入。引入可学习语义向量之后的预测概率可由如下公式描述:
实际使用过程中,可学习语义向量中的[CLASS]的位置可以是在结尾,也可以是中间,也可以是开头。比如一个句子,首先经过一个词嵌层(比如BPE网络)变成一个统一长度的向量(英文的长度是77,中文的长度是52),然后的话这个统一长度向量再经过一个embedding网络,提取这个向量的特征(比如维度是52x512)。之前的方法它这个是向量是固定的。但是这里是小样本学习,只有很少的样本(比如一张图像、2张图像等),无法更新整体的网络,所以就更新52x512这个向量。因此,在训练图像分类模型的过程中,固定视觉编码器与文本编码器的参数,对可学习语义向量进行更新。
具体来说,对其中16x512部分向量变成可学***均。
需要说明的是,在本发明中,训练样本即需进行分类的图像,训练样本可以是各个类型的图像,可以是风景图像,也可以是人物图像,本发明实施例对此不作具体限定。
图像数据,即训练样本中的对象,可以通过目标检测模型来对训练样本中的对象进行检测。
中文描述性数据为描述图像数据的文本,中文描述性数据与图像数据对应。例如,训练样本为包含有苹果的图像,则图像数据为苹果,中文描述性数据可以是“红色果实”。
在本实施例中,训练样本的形式包括图片、视频帧等,在此不作限定。而中文描述性数据的形式包括单字、词组、短语、句子、段落文章等,在此不作限定。
在一实施例中,所述图像类别表示位于所述可学习语义向量的开始位置、中间位置或结尾位置。
请参阅图3,图3为本发明一示例性实施例示出的利用训练集对初始模型进行训练的流程图。在图3中,利用训练集对初始模型进行训练,包括:
步骤S310,通过初始模型中的特征提取层对所述训练样本进行特征提取,得到第一文本特征和图像特征;
具体地,通过特征提取层对所述训练样本进行特征提取,包括:
通过特征提取层中的视觉编码器对所述训练样本中的图像数据进行特征提取,得到图像特征;
通过特征提取层是的文本编码器对所述训练样本中的中文描述性数据进行特征提取,得到第一文本特征;
其中,所述文本编码器和所述视觉编码器是以图像样本和对应所述图像样本的中文描述性文本样本构成的训练集对图像特征和第一文本特征进行对比学习训练得到。
在本发明中,特征提取层包括视觉编码器和文本编码器,其中,视觉编码器可以选择残差网络(ResNet)或者Transformer模型,文本编码器可以选择常用Transformer模型,比如Bert。通过特征提取层对所述训练样本进行特征提取,即基于残差网络ResNet对图像数据进行编码,得到图像数据的高层特征表示,即图像特征,以及基于Transformer模型对中文描述性数据进行编码,得到文本数据的高层特征表示,即第一文本特征。
具体地,基于残差网络ResNet对图像数据进行编码,包括:
对图像数据(即图片)进行预处理,设定图片输入分辨率,在图片缩放的基础上,采用中心裁剪的方法对图片进行裁剪,对缩放裁剪的图片进行归一化处理;通过提取归一化处理后的图像数据不同维度的特征来构成特征集;选取样本点并提取样本点的M维特征,每个样本的特征是一个大小为MXN的矩阵,使用随机擦除与变换对比度的方式对原图像数据进行增强;按照比例将数据集拆分为训练集和测试集,并将其全部转化为二进制文件,添加样本标签,将转换得到的TFRcords文件作为ResNet模型数据输入;进行训练ResNet模型得到图像数据的高层特征表示,即图像特征。
具体地,基于Transformer模型对中文描述性数据进行编码,得到文本数据的高层特征表示,包括:
通过分词去词的方法和采用Bert模型处理进行文本预处理,得到文本向量化表示;根据任务的分类标签构建每个类别的描述文本,将Transformer模型的编码器作为一个特征提取器,对中文描述性数据进行特征提取,以获取中文描述性数据的内部信息,得到中文描述性数据的高层特征表示,即第一文本特征。
需要说明的是,在本发明中,所述文本编码器和所述视觉编码器是以图像样本和对应所述图像样本的中文描述性文本样本构成的训练集对图像特征和第一文本特征进行对比学习训练得到。
训练适配中文的文本编码器和视觉编码器的核心在于收集大量的图像文本对(包括图像样本和对应所述图像样本的中文描述性文本样本)数据集,它的来源可以是网上开源的数据集也可以是用户自己收集的相关数据。采用对比学习训练这种方法的主要原因是加速网络的训练,如果采用生成式的方法,训练的难度比较大,模型不容易收敛;而采用对比学习训练配对的方式,降低了训练的难度。在得到了上述适配中文的文本编码器和视觉编码器的模型之后,就可以进行中文小样本学习了。
步骤S320,通过初始模型中的上下文感知层利用多头注意力机制使得图像特征为第一文本特征添加注意力,得到第二文本特征;
请参阅图4,图4为本发明的一示例性实施例示出的利用多头注意力机制使得图像特征为第一文本特征添加注意力的流程图。在图2中,利用多头注意力机制使得图像特征为第一文本特征添加注意力,包括:
步骤S410,对图像特征进行全局池化处理,得到全局特征;
步骤S420,将图像特征和全局特征进行特征融合,得到第一融合特征;
步骤S430,将第一融合特征输入到多头注意力网络中,得到第二融合特征;
步骤S440,将图像特征和第二融合特征进行特征融合,得到第二文本特征。
为了更好地说明上下文提示感知层,对视觉编码器的进行更为细致的描述。不失一般性,这里以残差网络ResNet为例,在视觉编码器编码的过程中,总共会经历四个阶段,然后记这些特征图为在CLIP中,它引入了额外的注意力池化层。残差网络Resnet实现是多个重复的模块进行叠加,每个模块都会让其特征图的大小降低一倍,浅层的特征提取是纹理等局部特征,越往后的阶段,提取的特征偏语义特征了。具体来说,在得到第4阶段的特征/>之后,然后网络在通过一个全局池化层得到全局特征/>这里H4、W4、C后网分别是第4阶段的图像高度、宽度和通道数。然后,将图像特征和全局特征进行特征融合,得到第一融合特征/>之后将融合特征送入到多头注意力网络中,得到第二融合特征,即/>
在原始的CLIP中,仅利用了经过全局池化的特征而没有利用到没有经过全局池化的特征z,然而没有经过全局池化的特征z具有一定的意义,主要是它保留了较好的空间信息。最后利用一个Transformer的解码器来融合原始的文本编码器得到的输出文本特征w和第二融合特征[z,z],同时借助残差学习的思想:transdecoder
其中控制着残差的比例。这样就介绍了上下文感知层了,它主要是为了增加视觉编码器和文本编码器之间的交互,利用视觉语义进一步提高文本语义的效果。
步骤S330,通过初始模型中的相似度度量层计算第二文本特征与图像特征之间的相似度;
中文描述性数据经过文本编码器进行特征提取后,得到多个文本特征,将每一个文本特征与图像特征分别进行相似度计算,具体可以采用余弦相似度进行表示。
步骤S340,基于相似度构建损失函数,并根据损失函数对初始模型进行迭代训练,得到图像分类模型。
在一实施例中,损失函数包括交叉熵损失函数和散度损失函数,其中,交叉熵损失函数用于约束第一预测分类值和真实分类值,散度损失函数用于约束第一预测分类值和第二预测分类值,第二预测分类值为零样本预测分类值。
损失函数在网络的训练过程中起了至关重要的作用。这里的损失函数由两部分构成,分别是利用交叉熵损失约束第一预测分类值和真实分类值,以及KL损失函数约束第一预测分类值和第二预测分类值。在本方法中,网络主要是去更新可学习语义向量t,首先定义交叉熵损失函数:
y代表着真实标签的one-hot编码。
然后在引入CLIP零样本推理结果pzs(wi|x),然后计算p(ti|x)和pzs(wi|x)的KL(Kullback-Leibler Divergence)散度损失函数:
最后再将两个损失函数loss结合到一起优化:
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
综上,本发明主要是利用CLIP作为算法基本框架,采用双塔架构,即视觉编码器和文本编码器。输入由两个部分构成:图像和文本。视觉编码器可以选择残差网络(ResNet)或者Transformer结构,文本编码器常用Transformer架构,比如Bert。最早的CLIP工作只是单纯的利用图像类别信息,然后再使用提示学习(prompt learning)方法将类别变成一个句子,比如“aphoto of{class_name}”。当零样本推理(也就是没有样本参与训练)效果不好的时候,我们需要利用少量样本(比如1个样本、2个样本等)优化特定场景下的模型。
在不同的数据集下,使用生硬的、静态的提示学习模板不一定是最佳的选择。因此,引入了“可学习语义向量”,它是连续的、可被学习的。可学习语义向量和图像原本的类别一起构成了文本编码器的输入。同时,也提出了上下文感知,它的作用主要是利用视觉语义来进一步提高文本语义的效果,这里简单使用Transformer架构的解码器部分。在得到图像特征和一系列文本特征之后,先进行L2归一化操作,然后计算图像特征和文本特征的相似度。在得到预测值之后,使用两种损失函数优化模型,分别是利用交叉熵损失约束预测值和真实值,以及KL损失约束预测值和零样本预测值,通过两个损失函数构成总的损失函数,然后根据损失函数对初始模型进行迭代训练,最终得到图像分类模型。
本发明通过在训练集中的中文描述性文本样本中加入了可学习语义向量,通过可学习语义向量的学习,提升了图像分类模型的性能,从而使得通过利用该图像分类模型对图像进行分类,分类更加准确。
图5是本发明的一示例性实施例示出的图像分类模型的训练装置的框图。该装置可以应用于图1所示的实施环境,并具体配置在服务器或终端设备中。该图像分类模型的训练装置也可以适用于其它的示例性实施环境,并具体配置在其它设备中,本实施例不对该图像分类模型的训练装置所适用的实施环境进行限制。
如图5所示,本发明还提供一种图像分类模型的训练装置,装置包括:
数据获取模块510,用于获取由多个训练样本构成的训练集,每个训练样本包括图像数据和对应图像数据的中文描述性数据;中文描述性数据包括具有图像类别表示的可学习语义向量;
训练模块520,用于基于小样本学习,利用训练集对初始模型进行训练,得到图像分类模型。
需要说明的是,上述实施例所提供的图像分类模型的训练装置与上述实施例所提供的图像分类模型的训练方法属于同一构思,其中各个模块和单元执行操作的具体方式已经在图像分类模型的训练方法实施例中进行了详细描述,此处不再赘述。上述实施例所提供的在实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将装置的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能,本处也不对此进行限制。
请参阅图6,图6为本发明的一示例性实施例示出的一种图像分类方法流程图,该分类方法包括:
步骤S610,获取待分类图像;
步骤S620,将待分类图像输入到的通过如图2所示的方法训练得到的图像分类模型中,以图像分类模型的输出作为待分类图像的类别。
请参阅图7,图7为本发明的一示例性实施例示出的一种图像分类装置的框图,分类装置包括:
图像获取模块710,用于获取待分类图像;
图像分类模块720,用于将待分类图像输入到的图像分类模型中,以图像分类模型的输出作为待分类图像的类别。
本申请实施例还提供了一种设备,该设备可以包括:一个或多个处理器;和其上存储有指令的一个或多个机器可读介质,当由一个或多个处理器执行时,使得设备执行图2的图像分类模型的训练方法。在实际应用中,该设备可以作为终端设备,也可以作为服务器,终端设备的例子可以包括:智能手机、平板电脑、电子书阅读器、MP3(动态影像专家压缩标准语音层面3,Moving Picture Experts Group Audio Layer III)播放器、MP4(动态影像专家压缩标准语音层面4,Moving Picture Experts GroupAudio Layer IV)播放器、膝上型便携计算机、车载电脑、台式计算机、机顶盒、智能电视机、可穿戴设备等等,本申请实施例对于具体的设备不加以限制。
本申请实施例还提供了一种非易失性可读存储介质,该存储介质中存储有一个或多个模块(programs),该一个或多个模块被应用在设备时,可以使得该设备执行本申请实施例的图2中的图像分类模型的训练方法所包含步骤的指令(instructions)。
图8为本申请一实施例提供的终端设备的硬件结构示意图。如图所示,该终端设备可以包括:输入设备1100、第一处理器1101、输出设备1102、第一存储器1103和至少一个通信总线1104。通信总线1104用于实现元件之间的通信连接。第一存储器1103可能包含高速RAM存储器,也可能还包括非易失性存储NVM,例如至少一个磁盘存储器,第一存储器1103中可以存储各种程序,用于完成各种处理功能以及实现本实施例的图像分类模型的训练方法步骤。
可选的,上述第一处理器1101例如可以为中央处理器(Central ProcessingUnit,简称CPU)、应用专用集成电路(ASIC)、数字信号处理器(DSP)、数字信号处理设备(DSPD)、可编程逻辑器件(PLD)、现场可编程门阵列(FPGA)、控制器、微控制器、微处理器或其他电子元件实现,该第一处理器1101通过有线或无线连接耦合到上述输入设备1100和输出设备1102。
可选的,上述输入设备1100可以包括多种输入设备,例如可以包括面向用户的用户接口、面向设备的设备接口、软件的可编程接口、摄像头、传感器中至少一种。可选的,该面向设备的设备接口可以是用于设备与设备之间进行数据传输的有线接口,还可以是用于设备与设备之间进行数据传输的硬件***接口(例如USB接口、串口等);可选的,该面向用户的用户接口例如可以是面向用户的控制按键、用于接收语音输入的语音输入设备以及用户接收用户触摸输入的触摸感知设备(例如具有触摸感应功能的触摸屏、触控板等);可选的,上述软件的可编程接口例如可以是供用户编辑或者修改程序的入口,例如芯片的输入引脚接口或者输入接口等;输出设备1102可以包括显示器、音响等输出设备。
在本实施例中,该终端设备的处理器包括用于执行各设备中各模块的功能,具体功能和技术效果参照上述实施例即可,此处不再赘述。
图9为本申请的一个实施例提供的终端设备的硬件结构示意图。图9是对图8在实现过程中的一个具体的实施例。如图所示,本实施例的终端设备可以包括第二处理器1201以及第二存储器1202。
第二处理器1201执行第二存储器1202所存放的计算机程序代码,实现上述实施例中图2的图像分类模型的训练方法的步骤。
第二存储器1202被配置为存储各种类型的数据以支持在终端设备的操作。这些数据的示例包括用于在终端设备上操作的任何应用程序或方法的指令,例如消息,图片,视频等。第二存储器1202可能包含随机存取存储器(random access memory,简称RAM),也可能还包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。
可选地,第二处理器1201设置在处理组件1200中。该终端设备还可以包括:通信组件1203,电源组件1204,多媒体组件1205,语音组件1206,输入/输出接口1207和/或传感器组件1208。终端设备具体所包含的组件等依据实际需求设定,本实施例对此不作限定。
处理组件1200通常控制终端设备的整体操作。处理组件1200可以包括一个或多个第二处理器1201来执行指令,以完成上述图像分类模型的训练方法中的全部或部分步骤。此外,处理组件1200可以包括一个或多个模块,便于处理组件1200和其他组件之间的交互。例如,处理组件1200可以包括多媒体模块,以方便多媒体组件1205和处理组件1200之间的交互。
电源组件1204为终端设备的各种组件提供电力。电源组件1204可以包括电源管理***,一个或多个电源,及其他与为终端设备生成、管理和分配电力相关联的组件。
多媒体组件1205包括在终端设备和用户之间提供的一个输出接口的显示屏。在一些实施例中,显示屏可以包括液晶显示器(LCD)和触摸面板(TP)。如果显示屏包括触摸面板,显示屏可以被实现为触摸屏,以接收来自用户的输入信号。触摸面板包括一个或多个触摸传感器以感测触摸、滑动和触摸面板上的手势。触摸传感器可以不仅感测触摸或滑动动作的边界,而且还检测与触摸或滑动操作相关的持续时间和压力。
语音组件1206被配置为输出和/或输入语音信号。例如,语音组件1206包括一个麦克风(MIC),当终端设备处于操作模式,如语音识别模式时,麦克风被配置为接收外部语音信号。所接收的语音信号可以被进一步存储在第二存储器1202或经由通信组件1203发送。在一些实施例中,语音组件1206还包括一个扬声器,用于输出语音信号。
输入/输出接口1207为处理组件1200和***接口模块之间提供接口,上述***接口模块可以是点击轮,按钮等。这些按钮可包括但不限于:音量按钮、启动按钮和锁定按钮。
传感器组件1208包括一个或多个传感器,用于为终端设备提供各个方面的状态评估。例如,传感器组件1208可以检测到终端设备的打开/关闭状态,组件的相对定位,用户与终端设备接触的存在或不存在。传感器组件1208可以包括接近传感器,被配置用来在没有任何的物理接触时检测附近物体的存在,包括检测用户与终端设备间的距离。在一些实施例中,该传感器组件1208还可以包括摄像头等。
通信组件1203被配置为便于终端设备和其他设备之间有线或无线方式的通信。终端设备可以接入基于通信标准的无线网络,如WiFi,2G或3G,或它们的组合。在一个实施例中,该终端设备中可以包括SIM卡插槽,该SIM卡插槽用于***SIM卡,使得终端设备可以登录GPRS网络,通过互联网与服务器建立通信。
由上可知,在图9实施例中所涉及的通信组件1203、语音组件1206以及输入/输出接口1207、传感器组件1208均可以作为图7实施例中的输入设备的实现方式。
上述实施例仅例示性说明本发明的原理及其功效,而非用于限制本发明。任何熟悉此技术的人士皆可在不违背本发明的精神及范畴下,对上述实施例进行修饰或改变。因此,举凡所属技术领域中具有通常知识者在未脱离本发明所揭示的精神与技术思想下所完成的一切等效修饰或改变,仍应由本发明的权利要求所涵盖。

Claims (12)

1.一种图像分类模型的训练方法,其特征在于,所述训练方法包括:
获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;
基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。
2.根据权利要求1所述的图像分类模型的训练方法,其特征在于,所述利用所述训练集对初始模型进行训练,包括:
通过所述初始模型中的特征提取层对所述训练样本进行特征提取,得到第一文本特征和图像特征;
通过所述初始模型中的上下文感知层利用多头注意力机制使得图像特征为第一文本特征添加注意力,得到第二文本特征;
通过所述初始模型中的相似度度量层计算第二文本特征与所述图像特征之间的相似度;
基于所述相似度构建损失函数,并根据所述损失函数对所述初始模型进行迭代训练,得到图像分类模型。
3.根据权利要求2所述的图像分类模型的训练方法,其特征在于,通过特征提取层对所述训练样本进行特征提取,包括:
通过特征提取层中的视觉编码器对所述训练样本中的图像数据进行特征提取,得到图像特征;
通过特征提取层是的文本编码器对所述训练样本中的中文描述性数据进行特征提取,得到第一文本特征;
其中,所述文本编码器和所述视觉编码器是以图像样本和对应所述图像样本的中文描述性文本样本构成的训练集对图像特征和第一文本特征进行对比学习训练得到。
4.根据权利要求3所述的图像分类模型的训练方法,其特征在于,利用多头注意力机制使得图像特征为第一文本特征添加注意力,得到第二文本特征,包括:
对所述图像特征进行全局池化处理,得到全局特征;
将所述图像特征和所述全局特征进行特征融合,得到第一融合特征;
将所述第一融合特征输入到多头注意力网络中,得到第二融合特征;
将所述图像特征和所述第二融合特征进行特征融合,得到第二文本特征。
5.根据权利要求1所述的图像分类模型的训练方法,其特征在于,所述图像类别表示位于所述可学习语义向量的开始位置、中间位置或结尾位置。
6.根据权利要求2所述的图像分类模型的训练方法,其特征在于,所述损失函数包括交叉熵损失函数和散度损失函数,其中,交叉熵损失函数用于约束第一预测分类值和真实分类值,所述散度损失函数用于约束第一预测分类值和第二预测分类值,第二预测分类值为零样本预测分类值。
7.根据权利要求3所述的图像分类模型的训练方法,其特征在于,在训练所述图像分类模型的过程中,固定所述视觉编码器与所述文本编码器的参数,对所述可学习语义向量进行更新。
8.一种图像分类模型的训练装置,其特征在于,所述训练装置包括:
数据获取模块,用于获取由多个训练样本构成的训练集,每个所述训练样本包括图像数据和对应所述图像数据的中文描述性数据;所述中文描述性数据包括具有图像类别表示的可学习语义向量;
训练模块,用于基于小样本学习,利用所述训练集对初始模型进行训练,得到图像分类模型。
9.一种图像分类方法,其特征在于,所述分类方法包括:
获取待分类图像;
将所述待分类图像输入到权利要求1~7任意一项所述的图像分类模型中,以所述图像分类模型的输出作为待分类图像的类别。
10.一种图像分类装置,其特征在于,所述分类装置包括:
图像获取模块,用于获取待分类图像;
图像分类模块,用于将所述待分类图像输入到权利要求1~7任意一项所述的图像分类模型中,以所述图像分类模型的输出作为待分类图像的类别。
11.一种电子设备,其特征在于,包括:
一个或多个处理器;和
其上存储有指令的一个或多个机器可读介质,当所述一个或多个处理器执行时,使得所述设备执行如权利要求1-7中一个或多个所述的图像分类模型的训练方法或权利要求9所述的图像分类方法。
12.一个或多个机器可读介质,其特征在于,其上存储有指令,当由一个或多个处理器执行时,使得设备执行如权利要求1-7中一个或多个所述的图像分类模型的训练方法或权利要求9所述的图像分类方法。
CN202310818502.4A 2023-07-05 2023-07-05 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备 Pending CN116740505A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310818502.4A CN116740505A (zh) 2023-07-05 2023-07-05 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310818502.4A CN116740505A (zh) 2023-07-05 2023-07-05 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备

Publications (1)

Publication Number Publication Date
CN116740505A true CN116740505A (zh) 2023-09-12

Family

ID=87915055

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310818502.4A Pending CN116740505A (zh) 2023-07-05 2023-07-05 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备

Country Status (1)

Country Link
CN (1) CN116740505A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117689974A (zh) * 2023-11-14 2024-03-12 荣耀终端有限公司 图像分类模型的训练方法、电子设备及可读存储介质

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117689974A (zh) * 2023-11-14 2024-03-12 荣耀终端有限公司 图像分类模型的训练方法、电子设备及可读存储介质

Similar Documents

Publication Publication Date Title
CN113010703B (zh) 一种信息推荐方法、装置、电子设备和存储介质
CN112200062B (zh) 一种基于神经网络的目标检测方法、装置、机器可读介质及设备
CN111026861B (zh) 文本摘要的生成方法、训练方法、装置、设备及介质
TW201712600A (zh) 用於自影像偵測與辨認文字之方法與系統
CN111666416B (zh) 用于生成语义匹配模型的方法和装置
CN106973244A (zh) 使用弱监督为图像配字幕
CN113704388A (zh) 多任务预训练模型的训练方法、装置、电子设备和介质
CN113515942A (zh) 文本处理方法、装置、计算机设备及存储介质
CN112200318B (zh) 一种目标检测方法、装置、机器可读介质及设备
CN113392687A (zh) 视频标题生成方法、装置、计算机设备及存储介质
CN115798459B (zh) 音频处理方法、装置、存储介质及电子设备
CN116740505A (zh) 图像分类模型的训练、图像分类方法、装置、机器可读介质及设备
CN113806588A (zh) 搜索视频的方法和装置
CN117216535A (zh) 推荐文本生成模型的训练方法、装置、设备及介质
CN116541492A (zh) 一种数据处理方法及相关设备
Li et al. [Retracted] Multimedia Data Processing Technology and Application Based on Deep Learning
CN117520498A (zh) 基于虚拟数字人交互处理方法、***、终端、设备及介质
CN117273019A (zh) 对话模型的训练方法、对话生成方法、装置和设备
CN117011875A (zh) 多媒体页面的生成方法、装置、设备、介质和程序产品
CN116861363A (zh) 多模态的特征处理方法、装置、存储介质与电子设备
CN116958851A (zh) 视频时效模型的训练方法、装置、设备及存储介质
CN111611420B (zh) 用于生成图像描述信息的方法和装置
CN114490946A (zh) 基于Xlnet模型的类案检索方法、***及设备
CN113722422A (zh) 模型训练方法、文本标签生成方法、装置、设备及介质
Hammad et al. Characterizing the impact of using features extracted from pre-trained models on the quality of video captioning sequence-to-sequence models

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination