CN113052257A - 一种基于视觉转换器的深度强化学习方法及装置 - Google Patents

一种基于视觉转换器的深度强化学习方法及装置 Download PDF

Info

Publication number
CN113052257A
CN113052257A CN202110393996.7A CN202110393996A CN113052257A CN 113052257 A CN113052257 A CN 113052257A CN 202110393996 A CN202110393996 A CN 202110393996A CN 113052257 A CN113052257 A CN 113052257A
Authority
CN
China
Prior art keywords
reinforcement learning
training sample
training
experience
sample images
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110393996.7A
Other languages
English (en)
Other versions
CN113052257B (zh
Inventor
金丹
王昭
龙玉婧
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
CETC Information Science Research Institute
Original Assignee
CETC Information Science Research Institute
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by CETC Information Science Research Institute filed Critical CETC Information Science Research Institute
Priority to CN202110393996.7A priority Critical patent/CN113052257B/zh
Publication of CN113052257A publication Critical patent/CN113052257A/zh
Application granted granted Critical
Publication of CN113052257B publication Critical patent/CN113052257B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Probability & Statistics with Applications (AREA)
  • Image Processing (AREA)

Abstract

本发明属于人工智能技术领域,提供一种基于视觉转换器的深度强化学习方法及装置,方法包括:构建基于视觉转换器的深度强化学习网络结构,视觉转换器包括多层感知器和转换编码器,转换编码器包括多头注意力层和前馈网络;初始化深度强化学习网络权重,根据存储器容量构建经验回放池;通过贪婪策略与运行环境交互,产生经验数据并将其放入经验回放池;当经验回放池中样本数量满足预设值时,从中随机抽取一批训练样本图像,对其预处理后,输入深度强化学习网络进行训练;在深度强化学习网络满足收敛条件时,获取强化学习模型。本发明可填补视觉转换器在强化学习领域应用的空白,提高强化学习方法的可解释性,更有效地进行学习训练。

Description

一种基于视觉转换器的深度强化学习方法及装置
技术领域
本发明属于人工智能技术领域,具体而言,涉及一种基于视觉转换器的深度强化学习方法及装置。
背景技术
近年来,强化学习逐渐成为机器学习领域的研究热点。智能体通过在与环境的交互过程中学习策略来实现回报的最大化或实现某种目标。通过与深度学习方法的结合,深度强化学习方法在许多人工智能任务中取得了突破,例如博弈游戏、机器人控制、群体决策、自动驾驶等。
目前,深度强化学习方法主要包括基于值函数的方法、基于策略梯度的方法和基于Actor-Critic框架的方法。在现有的强化学习网络框架中,所采用的网络结构主要是卷积神经网络和长短时记忆网路。卷积神经网络侧重于局部观测信息的提取,全局观测信息的捕捉能力弱。长短时记忆网络处理序列数据更具有优势,可以学习并长期保存信息,但长短时记忆网络作为一种循环网络结构,无法进行并行训练。
转换器(Transformer)在自然语言处理任务中得到了广泛应用,转换器架构可以避免递归,实现并行计算,通过自注意力机制对输入输出的全局依赖关系进行建模。然而,转换器在强化学习领域中还没有相应的研究。因此,需要提供一种改进的基于视觉转换器的深度强化学习方法。
发明内容
本发明旨在至少解决现有技术中存在的技术问题之一,提供一种基于视觉转换器的深度强化学习方法及装置。
本发明的一个方面,提供一种基于视觉转换器的深度强化学习方法,所述方法包括:
构建基于视觉转换器的深度强化学习网络结构,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络;
初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池;
通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池;
当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理;
将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练;
在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
在一些实施方式中,所述通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池,包括:
通过ε-greedy策略与运行环境进行交互,获取经验数据(s,a,r,s′)并将其放入所述经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
在一些实施方式中,所述当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,包括:
当所述经验回放池中的样本数量满足预设的训练样本数量m时,从所述经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理,根据所述训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为所述训练样本图像的高度,W为所述训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),并向其添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding;
将状态动作价值占位符QvalueToken通过学习参数的方式与所述色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),之后将处理后的数据输入所述视觉转换器,通过所述视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(X_attention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,WQ、WK、WV分别为线性映射的网络权重。
在一些实施方式中,所述将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练,包括:
依据均方误差损失函数L对所述深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure BDA0003017834960000031
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
本发明的另一个方面,提供一种基于视觉转换器的深度强化学习装置,所述装置包括构建模块、数据采集模块、输入模块、训练模块和获取模块:
所述构建模块,用于构建基于视觉转换器的深度强化学习网络结构,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络,初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池;
所述数据采集模块,用于通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池;
所述输入模块,用于当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,并将所述预处理后的训练样本图像输入所述训练模块;
所述训练模块,用于利用所述预处理后的训练样本图像对所述深度强化学习网络进行训练;
所述获取模块,用于在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
在一些实施方式中,所述数据采集模块具体用于:
通过ε-greedy策略与运行环境进行交互,获取经验数据(s,a,r,s′)并将其放入所述经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
在一些实施方式中,所述输入模块具体用于:
当所述经验回放池中的样本数量满足预设的训练样本数量m时,从所述经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理,根据所述训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为所述训练样本图像的高度,W为所述训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),并向其添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding;
将状态动作价值占位符QvalueToken通过学习参数的方式与所述色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),之后将处理后的数据输入所述视觉转换器,通过所述视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(X_attention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,WQ、WK、WV分别为线性映射的网络权重。
在一些实施方式中,所述训练模块具体用于:
依据均方误差损失函数L对所述深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure BDA0003017834960000051
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
本发明的另一个方面,提供一种电子设备,所述电子设备包括:
一个或多个处理器;
存储单元,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,能使得所述一个或多个处理器实现前文记载的所述的方法。
本发明的另一个方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时能实现根据前文记载的所述的方法。
本发明的基于视觉转换器的深度强化学习方法及装置,通过将视觉转换器引入深度强化学习网络,填补了视觉转换器在强化学习领域应用的空白,提高了强化学习方法的可解释性,能够更有效地进行学习训练,可应用于使用强化学习算法的场景,如游戏、机器人控制等。
附图说明
图1为本发明一实施例的电子设备的组成示意框图;
图2为本发明另一实施例的基于视觉转换器的深度强化学习方法的流程图;
图3为本发明另一实施例的基于视觉转换器的深度强化学习网络的结构示意图;
图4为本发明另一实施例的转换编码器的结构示意图;
图5为本发明另一实施例的基于视觉转换器的深度强化学习装置的结构示意图。
具体实施方式
为使本领域技术人员更好地理解本发明的技术方案,下面结合附图和具体实施方式对本发明作进一步详细描述。
首先,参照图1来描述用于实现本发明实施例的装置及方法的示例电子设备。
如图1所示,电子设备200包括一个或多个处理器210、一个或多个存储装置220、一个或多个输入装置230、一个或多个输出装置240等,这些组件通过总线***250和/或其他形式的连接机构互连。应当注意,图1所示的电子设备的组件和结构只是示例性的,而非限制性的,根据需要,电子设备也可以具有其他组件和结构。
处理器210可以是由多(众)核架构的芯片组成的神经网络处理器,也可以是单独的中央处理单元(CPU),或者,也可以是中央处理单元+多核神经网络处理器阵列或者具有数据处理能力和/或指令执行能力的其他形式的处理单元,并且可以控制电子设备200中的其他组件以执行期望的功能。
存储装置220可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器可以运行所述程序指令,以实现下文所述的本发明实施例中(由处理器实现)的客户端功能以及/或者其他期望的功能。在所述计算机可读存储介质中还可以存储各种应用程序和各种数据,例如,所述应用程序使用和/或产生的各种数据等。
输入装置230可以是用户用来输入指令的装置,并且可以包括键盘、鼠标、麦克风和触摸屏等中的一个或多个。
输出装置240可以向外部(例如用户)输出各种信息(例如图像或声音),并且可以包括显示器、扬声器等中的一个或多个。
下面,将参考图2描述根据本发明一实施例的基于视觉转换器的深度强化学习方法。
示例性的,如图2所示,本实施例提供一种基于视觉转换器的深度强化学习方法S100,方法S100包括:
S110、构建基于视觉转换器的深度强化学习网络结构,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络。
具体地,可以基于强化学习运行环境,定义状态空间、动作空间及奖励函数,构建基于视觉转换器的深度强化学习网络结构。其中,如图3所示,视觉转换器包括一个多层感知器和一个转换编码器。如图4所示,转换编码器包括多头注意力层和前馈网络。
S120、初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池。
具体地,可以将深度强化学习网络的各个权重进行初始化,根据存储器的容量大小建立经验回放池。
S130、通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池。
具体地,可以通过贪婪策略与强化学习运行环境进行交互,在交互过程中产生经验数据,并将该经验数据放入经验回放池。
S140、当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理。
具体地,当经验回放池中的样本数量满足预设的训练样本数量时,可以从经验回放池中随机抽取一批训练样本图像,然后根据实际需要,对这些训练样本图像进行预处理。
需要说明的是,预设的训练样本数量可以是对深度强化学习网络进行一次训练所需的最小训练样本数量,也可以是根据实际需要设定的任一训练样本数量,本领域技术人员可以按需进行选择,本实施例对此并不限制。
S150、将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练。
具体地,可以将预处理后的训练样本图像作为输入,对深度强化学习网络进行训练。
S160、在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
具体地,在对深度强化学习网络进行训练的过程中,当深度强化学习网络满足收敛条件时,获取当前的强化学习模型,以作为最终的强化学习模型。
本实施例的基于视觉转换器的深度强化学习方法,通过将视觉转换器引入深度强化学习网络,填补了视觉转换器在强化学习领域应用的空白,提高了强化学习方法的可解释性,能够更有效地进行学习训练,可应用于使用强化学习算法的场景,如游戏、机器人控制等。
示例性的,所述通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池,包括:
通过ε-greedy策略与运行环境进行交互,在进行交互时,输出动作会以ε的概率从所有动作中随机抽取一个动作,以1-ε的概率抽取价值最大的动作,获取经验数据(s,a,r,s′)并将其放入经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
示例性的,所述当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,包括:
当经验回放池中的样本数量满足预设的训练样本数量m时,从经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理。根据训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为训练样本图像的高度,W为训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),具体操作如pytorch深度学习框架中的全连接层torch.nn.Linear,并向D维向量X1添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding,具体操作如pytorch深度学习框架中的模块参数设置torch.nn.Parameter。
将状态动作价值占位符QvalueToken通过学习参数的方式与色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),具体操作如pytorch深度学习框架中的torch.nn.Identity函数。之后将处理后的数据输入视觉转换器,通过视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(X_attention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,即多头注意力层,WQ、WK、WV分别为线性映射的网络权重。
本实施例的基于视觉转换器的深度强化学习方法,通过视觉转换器的注意力机制,能够进一步提高强化学习方法的可解释性,并在提取局部观测信息的同时,进一步学习有用的全局观测信息,从而更好地捕捉全局信息。另外,本实施例通过利用视觉转换器的时序编码,使得深度强化学习网络可以利用过去时刻的观测信息,从而能够更有效地进行学习训练。
示例性的,所述将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练,包括:
依据均方误差损失函数L对深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure BDA0003017834960000091
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
本实施例的基于视觉转换器的深度强化学习方法,可以通过并行的方式对深度强化学习网络进行训练,从而加快深度强化学习网络的收敛速度。
本发明的另一个方面,提供一种基于视觉转换器的深度强化学习装置。
示例性的,如图5所示,本实施例提供一种基于视觉转换器的深度强化学习装置100,装置100包括构建模块110、数据采集模块120、输入模块130、训练模块140和获取模块150。该装置100可以应用于前文记载的方法,下述装置中未提及的具体内容可以参考前文相关记载,在此不作赘述。
构建模块110用于构建基于视觉转换器的深度强化学习网络结构,定义状态空间、动作空间及奖励函数,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络,初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池;
数据采集模块120用于通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池;
输入模块130用于当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,并将所述预处理后的训练样本图像输入所述训练模块140;
训练模块140用于利用所述预处理后的训练样本图像对所述深度强化学习网络进行训练;
获取模块150用于在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
本实施例的基于视觉转换器的深度强化学习装置,通过将视觉转换器引入深度强化学习网络,填补了视觉转换器在强化学习领域应用的空白,提高了强化学习方法的可解释性,能够更有效地进行学习训练,可应用于使用强化学习算法的场景,如游戏、机器人控制等。
示例性的,数据采集模块120具体用于:
通过ε-greedy策略与运行环境进行交互,获取经验数据(s,a,r,s′)并将其放入所述经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
示例性的,输入模块130具体用于:
当所述经验回放池中的样本数量满足预设的训练样本数量m时,从所述经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理,根据所述训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为所述训练样本图像的高度,W为所述训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),并向其添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding;
将状态动作价值占位符QvalueToken通过学习参数的方式与所述色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),之后将处理后的数据输入所述视觉转换器,通过所述视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(X_attention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,即多头注意力层,WQ、WK、WV分别为线性映射的网络权重。
本实施例的基于视觉转换器的深度强化学习装置,通过视觉转换器的注意力机制,能够进一步提高强化学习方法的可解释性,并在提取局部观测信息的同时,进一步学习有用的全局观测信息,从而更好地捕捉全局信息。另外,本实施例通过利用视觉转换器的时序编码,使得深度强化学习网络可以利用过去时刻的观测信息,从而能够更有效地进行学习训练。
示例性的,训练模块140具体用于:
依据均方误差损失函数L对所述深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure BDA0003017834960000121
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
本实施例的基于视觉转换器的深度强化学习装置,可以通过并行的方式对深度强化学习网络进行训练,从而加快深度强化学习网络的收敛速度。
本发明的另一个方面,提供一种电子设备,包括:
一个或多个处理器;
存储单元,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,能使得所述一个或多个处理器实现根据前文记载的所述的方法。
本发明的另一个方面,提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时能实现根据前文记载的所述的方法。
其中,计算机可读存储介质可以是本发明的装置、设备中所包含的,也可以是单独存在。
其中,计算机可读存储介质可以是任何包含或存储程序的有形介质,其可以是电、磁、光、电磁、红外线、半导体的***、装置、设备,更具体的例子包括但不限于:具有一个或多个导线的相连、便携式计算机磁盘、硬盘、光纤、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件,或它们任意合适的组合。
其中,计算机可读存储介质也可以包括在基带中或作为载波一部分传播的数据信号,其中承载了计算机可读的程序代码,其具体的例子包括但不限于电磁信号、光信号,或它们任意合适的组合。
可以理解的是,以上实施方式仅仅是为了说明本发明的原理而采用的示例性实施方式,然而本发明并不局限于此。对于本领域内的普通技术人员而言,在不脱离本发明的精神和实质的情况下,可以做出各种变型和改进,这些变型和改进也视为本发明的保护范围。

Claims (10)

1.一种基于视觉转换器的深度强化学习方法,其特征在于,所述方法包括:
构建基于视觉转换器的深度强化学习网络结构,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络;
初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池;
通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池;
当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理;
将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练;
在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
2.根据权利要求1所述的方法,其特征在于,所述通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池,包括:
通过ε-greedy策略与运行环境进行交互,获取经验数据(s,a,r,s′)并将其放入所述经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
3.根据权利要求2所述的方法,其特征在于,所述当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,包括:
当所述经验回放池中的样本数量满足预设的训练样本数量m时,从所述经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理,根据所述训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为所述训练样本图像的高度,W为所述训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),并向其添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding;
将状态动作价值占位符QvalueToken通过学习参数的方式与所述色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),之后将处理后的数据输入所述视觉转换器,通过所述视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(Xattention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,WQ、WK、WV分别为线性映射的网络权重。
4.根据权利要求3所述的方法,其特征在于,所述将所述预处理后的训练样本图像输入所述深度强化学习网络进行训练,包括:
依据均方误差损失函数L对所述深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure FDA0003017834950000021
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
5.一种基于视觉转换器的深度强化学习装置,其特征在于,所述装置包括构建模块、数据采集模块、输入模块、训练模块和获取模块:
所述构建模块,用于构建基于视觉转换器的深度强化学习网络结构,其中,所述视觉转换器包括多层感知器和转换编码器,所述转换编码器包括多头注意力层和前馈网络,初始化所述深度强化学习网络的权重,根据存储器的容量大小构建经验回放池;
所述数据采集模块,用于通过贪婪策略与运行环境进行交互,产生经验数据并将其放入所述经验回放池;
所述输入模块,用于当所述经验回放池中的样本数量满足预设的训练样本数量时,从所述经验回放池中随机抽取一批训练样本图像,对所述训练样本图像进行预处理,并将所述预处理后的训练样本图像输入所述训练模块;
所述训练模块,用于利用所述预处理后的训练样本图像对所述深度强化学习网络进行训练;
所述获取模块,用于在所述深度强化学习网络满足收敛条件时,获取强化学习模型。
6.根据权利要求5所述的装置,其特征在于,所述数据采集模块具体用于:
通过ε-greedy策略与运行环境进行交互,获取经验数据(s,a,r,s′)并将其放入所述经验回放池,其中,s为当前时刻的观测量,a为当前时刻动作,r为环境返回的回报,s'为下一时刻的观测量。
7.根据权利要求6所述的装置,其特征在于,所述输入模块具体用于:
当所述经验回放池中的样本数量满足预设的训练样本数量m时,从所述经验回放池中随机抽取数量为batch大小的训练样本图像,对尺寸大小为H*W的训练样本图像进行预处理,根据所述训练样本图像的大小将其分成N个色块,每个色块的尺寸大小为P*P,其中,H为所述训练样本图像的高度,W为所述训练样本图像的宽度,N=H*W/P2
使用线性投影矩阵将输入的t-2时刻、t-1时刻、t时刻图像中的每个色块X进行平化,得到映射后的D维向量X1=Embedding(X),并向其添加位置嵌入PositionEncoding和时序嵌入SequenceEncoding,以得到色块向量X2=X1+PositionEncoding+SequenceEncoding;
将状态动作价值占位符QvalueToken通过学习参数的方式与所述色块向量X2进行拼接,得到X3=Concat(X2,QvalueToken),之后将处理后的数据输入所述视觉转换器,通过所述视觉转换器输出动作状态值Xoutput,其中,
Xoutput=MLP(Xhidden),
Xhidden=LayerNorm(X_attention+FeedForward(Xattention)),
Xattention=LayerNorm(X3+SelfAttention(X3WQ,X3WK,X3WV)),
其中,MLP为多层感知器,Xhidden为转换编码器的输出,FeedForward为由两层线性映射和激活函数组成的前馈网络,Xattention为多头注意力层的输出,SelfAttention为自注意力层,WQ、WK、WV分别为线性映射的网络权重。
8.根据权利要求7所述的装置,其特征在于,所述训练模块具体用于:
依据均方误差损失函数L对所述深度强化学习网络进行训练,其中,L=E[r+γmaxa′Q(s′,a′;θ-)-Q(s,a;θ)]2,更新深度强化学习网络权重,
Figure FDA0003017834950000041
其中,E为数学期望,a为当前时刻动作,a′为下一时刻动作,α为学习率,γ为折扣系数,Q(s,a;θ)为当前值神经网络的Q值,Q(s′,a′;θ-)为目标值神经网络的Q值,θ和θ-分别为当前值神经网络的参数和目标值神经网络的参数,θ′为更新后的值神经网络的参数。
9.一种电子设备,其特征在于,所述电子设备包括:
一个或多个处理器;
存储单元,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,能使得所述一个或多个处理器实现根据权利要求1至4中任一项所述的方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时能实现根据权利要求1至4中任一项所述的方法。
CN202110393996.7A 2021-04-13 2021-04-13 一种基于视觉转换器的深度强化学习方法及装置 Active CN113052257B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110393996.7A CN113052257B (zh) 2021-04-13 2021-04-13 一种基于视觉转换器的深度强化学习方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110393996.7A CN113052257B (zh) 2021-04-13 2021-04-13 一种基于视觉转换器的深度强化学习方法及装置

Publications (2)

Publication Number Publication Date
CN113052257A true CN113052257A (zh) 2021-06-29
CN113052257B CN113052257B (zh) 2024-04-16

Family

ID=76519168

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110393996.7A Active CN113052257B (zh) 2021-04-13 2021-04-13 一种基于视觉转换器的深度强化学习方法及装置

Country Status (1)

Country Link
CN (1) CN113052257B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113469119A (zh) * 2021-07-20 2021-10-01 合肥工业大学 基于视觉转换器和图卷积网络的宫颈细胞图像分类方法
CN115147669A (zh) * 2022-06-24 2022-10-04 北京百度网讯科技有限公司 基于视觉转换器模型的图像处理方法、训练方法和设备

Citations (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150201895A1 (en) * 2012-08-31 2015-07-23 The University Of Chicago Supervised machine learning technique for reduction of radiation dose in computed tomography imaging
CN108288094A (zh) * 2018-01-31 2018-07-17 清华大学 基于环境状态预测的深度强化学习方法及装置
CN109241552A (zh) * 2018-07-12 2019-01-18 哈尔滨工程大学 一种基于多约束目标的水下机器人运动规划方法
US20190124348A1 (en) * 2017-10-19 2019-04-25 Samsung Electronics Co., Ltd. Image encoder using machine learning and data processing method of the image encoder
US20190258671A1 (en) * 2016-10-28 2019-08-22 Vilynx, Inc. Video Tagging System and Method
CN110286161A (zh) * 2019-03-28 2019-09-27 清华大学 基于自适应增强学习的主变压器故障诊断方法
CN110945495A (zh) * 2017-05-18 2020-03-31 易享信息技术有限公司 基于神经网络的自然语言查询到数据库查询的转换
CN111126282A (zh) * 2019-12-25 2020-05-08 中国矿业大学 一种基于变分自注意力强化学习的遥感图像内容描述方法
CN111461321A (zh) * 2020-03-12 2020-07-28 南京理工大学 基于Double DQN的改进深度强化学习方法及***
CN111597830A (zh) * 2020-05-20 2020-08-28 腾讯科技(深圳)有限公司 基于多模态机器学习的翻译方法、装置、设备及存储介质
CN111666500A (zh) * 2020-06-08 2020-09-15 腾讯科技(深圳)有限公司 文本分类模型的训练方法及相关设备
CN111709398A (zh) * 2020-07-13 2020-09-25 腾讯科技(深圳)有限公司 一种图像识别的方法、图像识别模型的训练方法及装置
KR20200132665A (ko) * 2019-05-17 2020-11-25 삼성전자주식회사 집중 레이어를 포함하는 생성기를 기반으로 예측 이미지를 생성하는 장치 및 그 제어 방법
US20200379461A1 (en) * 2019-05-29 2020-12-03 Argo AI, LLC Methods and systems for trajectory forecasting with recurrent neural networks using inertial behavioral rollout
CN112084314A (zh) * 2020-08-20 2020-12-15 电子科技大学 一种引入知识的生成式会话***
CN112261725A (zh) * 2020-10-23 2021-01-22 安徽理工大学 一种基于深度强化学习的数据包传输智能决策方法
US20210073995A1 (en) * 2019-09-11 2021-03-11 Nvidia Corporation Training strategy search using reinforcement learning
CN112488306A (zh) * 2020-12-22 2021-03-12 中国电子科技集团公司信息科学研究院 一种神经网络压缩方法、装置、电子设备和存储介质

Patent Citations (18)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150201895A1 (en) * 2012-08-31 2015-07-23 The University Of Chicago Supervised machine learning technique for reduction of radiation dose in computed tomography imaging
US20190258671A1 (en) * 2016-10-28 2019-08-22 Vilynx, Inc. Video Tagging System and Method
CN110945495A (zh) * 2017-05-18 2020-03-31 易享信息技术有限公司 基于神经网络的自然语言查询到数据库查询的转换
US20190124348A1 (en) * 2017-10-19 2019-04-25 Samsung Electronics Co., Ltd. Image encoder using machine learning and data processing method of the image encoder
CN108288094A (zh) * 2018-01-31 2018-07-17 清华大学 基于环境状态预测的深度强化学习方法及装置
CN109241552A (zh) * 2018-07-12 2019-01-18 哈尔滨工程大学 一种基于多约束目标的水下机器人运动规划方法
CN110286161A (zh) * 2019-03-28 2019-09-27 清华大学 基于自适应增强学习的主变压器故障诊断方法
KR20200132665A (ko) * 2019-05-17 2020-11-25 삼성전자주식회사 집중 레이어를 포함하는 생성기를 기반으로 예측 이미지를 생성하는 장치 및 그 제어 방법
US20200379461A1 (en) * 2019-05-29 2020-12-03 Argo AI, LLC Methods and systems for trajectory forecasting with recurrent neural networks using inertial behavioral rollout
US20210073995A1 (en) * 2019-09-11 2021-03-11 Nvidia Corporation Training strategy search using reinforcement learning
CN111126282A (zh) * 2019-12-25 2020-05-08 中国矿业大学 一种基于变分自注意力强化学习的遥感图像内容描述方法
CN111461321A (zh) * 2020-03-12 2020-07-28 南京理工大学 基于Double DQN的改进深度强化学习方法及***
CN111597830A (zh) * 2020-05-20 2020-08-28 腾讯科技(深圳)有限公司 基于多模态机器学习的翻译方法、装置、设备及存储介质
CN111666500A (zh) * 2020-06-08 2020-09-15 腾讯科技(深圳)有限公司 文本分类模型的训练方法及相关设备
CN111709398A (zh) * 2020-07-13 2020-09-25 腾讯科技(深圳)有限公司 一种图像识别的方法、图像识别模型的训练方法及装置
CN112084314A (zh) * 2020-08-20 2020-12-15 电子科技大学 一种引入知识的生成式会话***
CN112261725A (zh) * 2020-10-23 2021-01-22 安徽理工大学 一种基于深度强化学习的数据包传输智能决策方法
CN112488306A (zh) * 2020-12-22 2021-03-12 中国电子科技集团公司信息科学研究院 一种神经网络压缩方法、装置、电子设备和存储介质

Non-Patent Citations (8)

* Cited by examiner, † Cited by third party
Title
DOSOVITSKIY, ALEXEY, ET AL: "\"An image is worth 16x16 words: Transformers for image recognition at scale\"", 《ARXIV》, pages 1 - 4 *
HAOYI ZHOU, 等: ""Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting"", 《ARXIV》, 28 March 2021 (2021-03-28), pages 1 - 15 *
IKER PENG: ""强化学习在视觉上的应用(RL for computer Vision)"", pages 1 - 9, Retrieved from the Internet <URL:《https://zhuanlan.zhihu.com/p/51202503》> *
J. KULHÁNEK等: ""Visual Navigation in Real-World Indoor Environments Using End-to-End Deep Reinforcement Learning"", 《IEEE ROBOTICS AND AUTOMATION LETTERS》, vol. 6, no. 3, 23 March 2021 (2021-03-23), pages 4345 - 4352 *
人工智能学术前沿: ""用于时间序列预测的深度Transformer模型:流感流行病例"", pages 1 - 6, Retrieved from the Internet <URL:《https://zhuanlan.zhihu.com/p/151423371》> *
李峰: ""深度强化学习必看经典论文:DQN,DDQN,Prioritized,Dueling,Rainbow"", pages 1 - 2, Retrieved from the Internet <URL:《https://zhuanlan.zhihu.com/p/337553995》> *
李飞雨: ""基于强化学习和机器翻译质量评估的中朝机器翻译研究"", 《计算机应用研究》 *
郝燕龙: ""基于密集卷积神经网络特征提取的图像描述模型研究"", 《中国优秀硕士学位论文全文数据库 信息科技辑》, no. 9, 15 September 2019 (2019-09-15), pages 138 - 1158 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113469119A (zh) * 2021-07-20 2021-10-01 合肥工业大学 基于视觉转换器和图卷积网络的宫颈细胞图像分类方法
CN113469119B (zh) * 2021-07-20 2022-10-04 合肥工业大学 基于视觉转换器和图卷积网络的宫颈细胞图像分类方法
CN115147669A (zh) * 2022-06-24 2022-10-04 北京百度网讯科技有限公司 基于视觉转换器模型的图像处理方法、训练方法和设备
CN115147669B (zh) * 2022-06-24 2023-04-18 北京百度网讯科技有限公司 基于视觉转换器模型的图像处理方法、训练方法和设备

Also Published As

Publication number Publication date
CN113052257B (zh) 2024-04-16

Similar Documents

Publication Publication Date Title
US20210390653A1 (en) Learning robotic tasks using one or more neural networks
US11373087B2 (en) Method and apparatus for generating fixed-point type neural network
CN109464803B (zh) 虚拟对象控制、模型训练方法、装置、存储介质和设备
KR102387570B1 (ko) 표정 생성 방법, 표정 생성 장치 및 표정 생성을 위한 학습 방법
CN113039555B (zh) 在视频剪辑中进行动作分类的方法、***及存储介质
CN110476173B (zh) 利用强化学习的分层设备放置
CN116415654A (zh) 一种数据处理方法及相关设备
CN112052948B (zh) 一种网络模型压缩方法、装置、存储介质和电子设备
CN108665506A (zh) 图像处理方法、装置、计算机存储介质及服务器
CN110795549B (zh) 短文本对话方法、装置、设备及存储介质
CN111292262B (zh) 图像处理方法、装置、电子设备以及存储介质
CN113052257A (zh) 一种基于视觉转换器的深度强化学习方法及装置
CN109858046A (zh) 利用辅助损失来学习神经网络中的长期依赖性
CN112216307A (zh) 语音情感识别方法以及装置
CN111340190A (zh) 构建网络结构的方法与装置、及图像生成方法与装置
KR20200076461A (ko) 중첩된 비트 표현 기반의 뉴럴 네트워크 처리 방법 및 장치
JP2020123345A (ja) Ganを用いて仮想世界における仮想データから取得したトレーニングデータを生成して、自律走行用ニューラルネットワークの学習プロセスに必要なアノテーションコストを削減する学習方法や学習装置、それを利用したテスト方法やテスト装置
CN116188621A (zh) 基于文本监督的双向数据流生成对抗网络图像生成方法
CN116912629A (zh) 基于多任务学习的通用图像文字描述生成方法及相关装置
CN113554040B (zh) 一种基于条件生成对抗网络的图像描述方法、装置设备
CN117011403A (zh) 生成图像数据的方法及装置、训练方法、电子设备
KR102597184B1 (ko) 가지치기 기반 심층 신경망 경량화에 특화된 지식 증류 방법 및 시스템
CN115984652B (zh) 符号生成***的训练方法、装置、电子设备和存储介质
US20220189171A1 (en) Apparatus and method for prediction of video frame based on deep learning
US20240096071A1 (en) Video processing method using transfer learning and pre-training server

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant