CN116702872A - 基于离线预训练状态转移Transformer模型的强化学习方法和装置 - Google Patents
基于离线预训练状态转移Transformer模型的强化学习方法和装置 Download PDFInfo
- Publication number
- CN116702872A CN116702872A CN202310737435.3A CN202310737435A CN116702872A CN 116702872 A CN116702872 A CN 116702872A CN 202310737435 A CN202310737435 A CN 202310737435A CN 116702872 A CN116702872 A CN 116702872A
- Authority
- CN
- China
- Prior art keywords
- state transition
- state
- reinforcement learning
- training
- transducer model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 230000007704 transition Effects 0.000 title claims abstract description 168
- 238000000034 method Methods 0.000 title claims abstract description 55
- 238000012549 training Methods 0.000 title claims abstract description 51
- 230000002787 reinforcement Effects 0.000 claims abstract description 88
- 230000009471 action Effects 0.000 claims description 16
- 238000012546 transfer Methods 0.000 claims description 7
- 230000006870 function Effects 0.000 claims description 5
- 230000005284 excitation Effects 0.000 claims description 4
- 238000012512 characterization method Methods 0.000 claims 1
- 238000011160 research Methods 0.000 abstract description 3
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 230000000007 visual effect Effects 0.000 description 10
- 230000007613 environmental effect Effects 0.000 description 4
- 238000012986 modification Methods 0.000 description 4
- 230000004048 modification Effects 0.000 description 4
- 230000003993 interaction Effects 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000004075 alteration Effects 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 230000008485 antagonism Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- A—HUMAN NECESSITIES
- A63—SPORTS; GAMES; AMUSEMENTS
- A63F—CARD, BOARD, OR ROULETTE GAMES; INDOOR GAMES USING SMALL MOVING PLAYING BODIES; VIDEO GAMES; GAMES NOT OTHERWISE PROVIDED FOR
- A63F13/00—Video games, i.e. games using an electronically generated display having two or more dimensions
- A63F13/60—Generating or modifying game content before or while executing the game program, e.g. authoring tools specially adapted for game development or game-integrated level editor
- A63F13/67—Generating or modifying game content before or while executing the game program, e.g. authoring tools specially adapted for game development or game-integrated level editor adaptively or by learning from player actions, e.g. skill level adjustment or by storing successful combat sequences for re-use
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/766—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using regression, e.g. by projecting features on hyperplanes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7753—Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/40—Scenes; Scene-specific elements in video content
- G06V20/46—Extracting features or characteristics from the video content, e.g. video fingerprints, representative shots or key frames
-
- A—HUMAN NECESSITIES
- A63—SPORTS; GAMES; AMUSEMENTS
- A63F—CARD, BOARD, OR ROULETTE GAMES; INDOOR GAMES USING SMALL MOVING PLAYING BODIES; VIDEO GAMES; GAMES NOT OTHERWISE PROVIDED FOR
- A63F2300/00—Features of games using an electronically generated display having two or more dimensions, e.g. on a television screen, showing representations related to the game
- A63F2300/60—Methods for processing data by generating or executing the game program
- A63F2300/6027—Methods for processing data by generating or executing the game program using adaptive systems learning from user actions, e.g. for skill level adjustment
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y04—INFORMATION OR COMMUNICATION TECHNOLOGIES HAVING AN IMPACT ON OTHER TECHNOLOGY AREAS
- Y04S—SYSTEMS INTEGRATING TECHNOLOGIES RELATED TO POWER NETWORK OPERATION, COMMUNICATION OR INFORMATION TECHNOLOGIES FOR IMPROVING THE ELECTRICAL POWER GENERATION, TRANSMISSION, DISTRIBUTION, MANAGEMENT OR USAGE, i.e. SMART GRIDS
- Y04S10/00—Systems supporting electrical power generation, transmission or distribution
- Y04S10/50—Systems or methods supporting the power network operation or management, involving a certain degree of interaction with the load-side end user applications
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Health & Medical Sciences (AREA)
- Multimedia (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Feedback Control In General (AREA)
Abstract
本发明公开了基于离线预训练状态转移Transformer模型的强化学习方法,属于人工智能技术领域。方法包括基于视频的观测数据离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;利用所述状态转移Transformer模型,得到强化学习中的状态转移的判别得分作为内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。本发明提出的方法比基线算法更具鲁棒性、样本效率和性能,在机器人控制、自动驾驶等领域具有很高的研究价值。
Description
技术领域
本发明涉及人工智能技术领域,尤其涉及一种基于离线预训练状态转移Transformer模型的强化学习方法和装置。
背景技术
从视觉观测数据中训练强化学习策略是一项具有挑战性的研究,其难点主要在于处理高维输入依赖大量计算资源、缺乏明确的动作信息、视觉数据的复杂性需要强大的特征提取技术、时间依赖性等。
目前的训练方法中,有一些采用从零开始的在线强化学习方案,这种方法采样效率低,难以进行有效样本探索和高难度探索,而且在线学习鉴别器的对抗性方法容易受到视觉观察中的噪声或局部变化导致的误分类影响;还有一些仅针对特定任务进行强化学习的策略训练,泛化能力弱,不适合处理开放性任务。因此,现有的观测学习方法适用范围有限:许多仅适用于向量观测环境,在应用于高维视觉观察或视频游戏时效果不好。
发明内容
为了解决现有技术中存在的问题,本发明提供了如下技术方案。
本发明第一方面提供了一种基于离线预训练状态转移Transformer模型的强化学习方法,包括:
基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
优选地,所述当前状态按照如下方法从所述视频中提取:在Atari环境中,当前状态通过所述视频中相邻的四帧观测数据堆叠得到;在MineCraft环境中,当前状态包括所述视频中当前的一帧观测数据。
优选地,所述基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型包括:
将相邻两个时间步的状态分别输入到特征编码器中,得到对应的状态表征et和et+1;
将状态表征et输入到Transformer模型中,预测得到下一步的状态表征
将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分;
迭代训练,使真实状态转移的判别得分增高,虚假状态转移的判别得分降低,直至达到训练目标。
优选地,所述强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略包括:强化学习的智能体在环境中交互,在状态转移Transformer模型计算出的内在奖励的激励下通过最大化如下目标J迭代更新策略实现策略提升,最终得到最优策略:
其中,π表示策略,ρ0表示初始状态分布,at表示在当前状态st下根据策略分布π(·|st)执行的动作,(st,st+1)表示当前时刻状态到下一时刻状态的转移,表示状态转移函数,γ为折扣因子,r(st,st+1)表示由状态转移Transformer模型针对(st,st+1)给出的内在奖励,/>表示期望,J表示最大化目标即最大化折扣奖励和的期望。
优选地,所述将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分,之后还包括:计算虚假状态转移的判别得分与真实状态转移的判别得分之间的差值,得到真实的状态转移与虚假的状态转移之间的差距;
所述利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励即为:利用真实的状态转移与虚假的状态转移之间的差距作为强化学习的内在奖励。
优选地,利用自监督时序距离预测方法学习状态观测的时序连续的特征表示,同时采用对抗学习的方法,通过判别器判别评分指导在特征表示的空间中精准预测单步转移规律。
优选地,所述强化学习中的状态转移按照如下方法获取:在强化学习中,智能体获取到环境的当前状态,并根据策略基于环境的当前状态选出执行的动作,智能体根据选出的执行的动作与环境交互产生状态转移。
本发明第二方面提供了一种基于离线预训练状态转移Transformer模型的强化学习装置,包括:
状态转移Transformer模型离线预训练模块,用于基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
强化学习策略训练模块,用于利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
本发明第三方面提供了一种存储器,存储有多条指令,所述指令用于实现如第一方面所述的基于离线预训练状态转移Transformer模型的强化学习方法。
本发明第四方面提供了一种电子设备,包括处理器和与所述处理器连接的存储器,所述存储器存储有多条指令,所述指令可被所述处理器加载并执行,以使所述处理器能够执行如第一方面所述的基于离线预训练状态转移Transformer模型的强化学习方法。
本发明的有益效果是:本发明提供了一种两阶段的基于离线预训练状态转移Transformer模型的强化学习方法,为使智能体能够有效地从视觉观察中学习提供了一种创新方法。其中,状态转移Transformer模型能够在仅基于视觉观察的情况下进行离线预训练得到,然后在没有任何环境奖励的情况下指导在线强化学习策略的训练。另外,通过状态转移判别器和自监督时间回归联合预测潜在转换,将自注意力集成到每个模块中以捕捉时间变化,从而在下游强化学习任务中提高了性能。通过在各种Atari和Minecraft环境中对训练得到的策略进行的实验验证了本发明提出的方法比基线算法更具鲁棒性、样本效率和性能。并且,在某些任务中甚至达到了与从显式环境奖励中学习的策略相当的性能。从视觉观察中进行强化学习,对于那些有视频演示可用,但环境交互受限且标记动作既昂贵又危险的情况,本发明提供的方法具有巨大的潜力,譬如在机器人控制、自动驾驶等领域具有很高的研究价值。
附图说明
图1为本发明所述基于离线预训练状态转移Transformer模型的强化学习方法的流程示意图;
图2为本发明所述基于离线预训练状态转移Transformer模型的强化学习方法的框架示意图;
图3为本发明所述基于离线预训练状态转移Transformer模型的强化学习装置的功能模块结构示意图。
具体实施方式
为了更好地理解上述技术方案,下面将结合说明书附图以及具体的实施方式对上述技术方案做详细的说明。
本发明提供的方法可以在如下的终端环境中实施,该终端可以包括一个或多个如下部件:处理器、存储器和显示屏。其中,存储器中存储有至少一条指令,所述指令由处理器加载并执行以实现下述实施例所述的方法。
处理器可以包括一个或者多个处理核心。处理器利用各种接口和线路连接整个终端内的各个部分,通过运行或执行存储在存储器内的指令、程序、代码集或指令集,以及调用存储在存储器内的数据,执行终端的各种功能和处理数据。
存储器可以包括随机存储器(Random Access Memory,RAM),也可以包括只读存储器(Read-Only Memory,ROM)。存储器可用于存储指令、程序、代码、代码集或指令。
显示屏用于显示各个应用程序的用户界面。
除此之外,本领域技术人员可以理解,上述终端的结构并不构成对终端的限定,终端可以包括更多或更少的部件,或者组合某些部件,或者不同的部件布置。比如,终端中还包括射频电路、输入单元、传感器、音频电路、电源等部件,在此不再赘述。
实施例一
如图1、2所示,本发明实施例提供了一种基于离线预训练状态转移Transformer模型的强化学习方法,包括:
S101,基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
S102,利用所述状态转移Transformer模型,得到强化学习中的状态转移的判别得分作为内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
本发明提供的基于离线预训练状态转移Transformer模型的强化学习方法包括两阶段。在第一阶段(阶段一、离线预训练),基于视频的观测数据,离线预训练得到了一个状态转移Transformer模型,可以有效捕捉演示视频中的信息,以预测观测状态的隐层转换。
在第二阶段(阶段二、在线强化学习),利用第一阶段得到的状态转移Transformer模型为下游强化学习任务提供内在奖励,智能体可以仅从这个单独的内在奖励中进行学习和迭代策略,而无需环境奖励的指导。
在步骤S101中,所述当前状态可以按照如下方法从所述视频中提取:在Atari环境中,当前状态通过所述视频中相邻的四帧观测数据堆叠得到;在MineCraft环境中,当前状态包括所述视频中当前的一帧观测数据。如图2所示,相邻两个时间步的当前状态和均为相邻的四帧观测数据堆叠得到的。
其中,Atari环境是经典的街机游戏环境,由于其中每个任务都可以被建模为马尔可夫决策过程,因此成为一种流行的检验强化学习算法在视觉控制任务上应用的测试环境。在Atari环境中,为了确保状态反映游戏动态信息,当前状态由相邻四帧所观测的灰度游戏画面堆叠得到。
MineCraft环境是近期逐渐热门的3D游戏环境,由Minedojo提供模拟器接口,包含数千个开放式开放探索任务。智能体在场景复杂的MineCraft环境场景中完成任务的表现能够更充分地体现算法性能。由于Minedojo模拟器仅支持单帧观测状态转移,因此为对齐Atari中三维状态表示,MineCraft环境中的状态定义为智能体当前观测到的三通道第一人称视角图像。
在本发明实施例中,视频的观测数据可以按照如下方法获得:
其中,Atari环境的观测数据来自Google Dopamine(谷歌的一种开源强化学习框架)。对于每个Atari任务,观测数据集源于DQN(深度Q学习算法)50轮训练后经验回放池中最后储存的十万帧尺寸调整为(84,84)的灰度游戏画面。
MineCraft环境的观测数据来自相关研究Plan4MC(一种基于规划的解决开放式MineCraft任务的方法)。首先训练Plan4MC智能体,采取习得的专家策略收集五万帧尺寸为(160,256,3)的第一人称游戏画面构成专家观测数据集。
执行步骤S101,如图2所示,基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型可以包括:
将相邻两个时间步的状态和/>分别输入到特征编码器中,得到对应的状态表征et和et+1;
将状态表征et输入到Transformer模型中,预测得到下一步的状态表征
将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分;
迭代训练,使真实状态转移的判别得分尽可能高,虚假状态转移的判别得分尽可能低,直至达到训练目标。
进一步地,所述将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分,之后还包括:计算虚假状态转移的判别得分与真实状态转移的判别得分之间的差值,得到真实的状态转移与虚假的状态转移之间的差距;
所述利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励即为:利用真实的状态转移与虚假的状态转移之间的差距作为强化学习的内在奖励。
另外,本发明实施例中利用自监督时序距离预测方法学习状态观测的时序连续的特征表示,同时采用对抗学习的方法,通过判别器判别评分指导在特征表示的空间中精准预测单步转移规律。从而,预训练完成的状态转移Transformer和判别器在强化学习过程中针对在线采集的观测序列提供内在奖励,从而提高下游强化学习任务的性能。
执行步骤S102,在强化学习中,智能体获取到环境的当前状态,并根据策略基于环境的当前状态选出执行的动作,智能体根据选出的执行的动作与环境交互产生状态转移。利用步骤S101训练得到的所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
具体的,可如图2所示,在强化学习的策略训练中,智能体获取到视觉观测环境的当前状态,然后智能体根据策略πθ,基于当前状态选出执行的动作at,再然后智能体根据选出的执行的动作at与环境进行交互,从而产生状态转移。之后,可以利用离线预训练得到的状态转移Transformer模型,针对智能体与环境交互产生的状态转移的情况得到内在奖励最后智能体根据内在奖励/>更新策略πθ。迭代训练,直至得到最优策略。
在本发明实施例中,所述强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略可以包括:强化学习的智能体在环境中交互,在状态转移Transformer模型计算出的内在奖励的激励下通过最大化如下目标J迭代更新策略实现策略提升,最终得到最优策略:
其中,π表示策略,ρ0表示初始状态分布,at表示在当前状态t下根据策略分布π(·|st)执行的动作,(st,st+1)表示当前时刻状态到下一时刻状态的转移,表示状态转移函数,γ为折扣因子,r(st,st+1)表示由状态转移Transformer模型针对(st,st+1)给出的内在奖励,表示期望,J表示最大化目标即最大化折扣奖励和的期望。
实施例二
如图3所示,本发明的另一方面还包括和前述方法流程完全对应一致的功能模块架构,即本发明实施例还提供了基于离线预训练状态转移Transformer模型的强化学习装置,包括:
状态转移Transformer模型离线预训练模块201,用于基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
强化学习策略训练模块202,用于利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
在状态转移Transformer模型离线预训练模块201中,所述当前状态按照如下方法从所述视频中提取:在Atari环境中,当前状态通过所述视频中相邻的四帧观测数据堆叠得到;在MineCraft环境中,当前状态包括所述视频中当前的一帧观测数据。
进一步地,所述基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型包括:
将相邻两个时间步的状态分别输入到特征编码器中,得到对应的状态表征et和et+1;
将状态表征et输入到Transformer模型中,预测得到下一步的状态表征
将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分;
迭代训练,使真实状态转移的判别得分尽可能高,虚假状态转移的判别得分尽可能低,直至达到训练目标。
所述将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分,之后还包括:计算虚假状态转移的判别得分与真实状态转移的判别得分之间的差值,得到真实的状态转移与虚假的状态转移之间的差距;
所述利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励即为:利用真实的状态转移与虚假的状态转移之间的差距作为强化学习的内在奖励。
利用自监督时序距离预测方法学习状态观测的时序连续的特征表示,同时采用对抗学习的方法,通过判别器判别评分指导在特征表示的空间中精准预测单步转移规律。
在强化学习策略训练模块202中,所述强化学习中的状态转移按照如下方法获取:在强化学习中,智能体获取到环境的当前状态,并根据策略基于环境的当前状态选出执行的动作,智能体根据选出的执行的动作与环境交互产生状态转移。
进一步地,所述强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略包括:强化学习的智能体在环境中交互,在状态转移Transformer模型计算出的内在奖励的激励下通过最大化如下目标J迭代更新策略实现策略提升,最终得到最优策略:
其中,π表示策略,ρ0表示初始状态分布,at表示在当前状态st下根据策略分布π(·|st)执行的动作,(st,st+1)表示当前时刻状态到下一时刻状态的转移,表示状态转移函数,γ为折扣因子,r(st,st+1)表示由状态转移Transformer模型针对(st,st+1)给出的内在奖励,/>表示期望,J表示最大化目标即最大化折扣奖励和的期望。
该装置可通过上述实施例一提供的基于离线预训练状态转移Transformer模型的强化学习方法实现,具体的实现方法可参见实施例一中的描述,在此不再赘述。
本发明还提供了一种存储器,存储有多条指令,所述指令用于实现如实施例一所述的基于离线预训练状态转移Transformer模型的强化学习方法。
本发明还提供了一种电子设备,包括处理器和与所述处理器连接的存储器,所述存储器存储有多条指令,所述指令可被所述处理器加载并执行,以使所述处理器能够执行如实施例一所述的基于离线预训练状态转移Transformer模型的强化学习方法。
采用本发明提供的技术方案,从视觉观察中进行强化学习,对于那些有视频演示可用,但环境交互受限且标记动作既昂贵又危险的情况具有巨大的潜力,譬如在机器人控制、自动驾驶等领域具有很高的研究价值。
尽管已描述了本发明的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本发明范围的所有变更和修改。显然,本领域的技术人员可以对本发明进行各种改动和变型而不脱离本发明的精神和范围。这样,倘若本发明的这些修改和变型属于本发明权利要求及其等同技术的范围之内,则本发明也意图包含这些改动和变型在内。
Claims (10)
1.一种基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,包括:
基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
2.如权利要求1所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,所述当前状态按照如下方法从所述视频中提取:在Atari环境中,当前状态通过所述视频中相邻的四帧观测数据堆叠得到;在MineCraft环境中,当前状态包括所述视频中当前的一帧观测数据。
3.如权利要求1所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,所述基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型包括:
将相邻两个时间步的状态分别输入到特征编码器中,得到对应的状态表征et和et+1;
将状态表征et输入到Transformer模型中,预测得到下一步的状态表征
将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分;
迭代训练,使真实状态转移的判别得分增高,虚假状态转移的判别得分降低,直至达到训练目标。
4.如权利要求3所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,所述将et、et+1和分别输入到状态转移判别器中,得到从et到et+1的真实状态转移的判别得分,以及et到/>的虚假状态转移的判别得分,之后还包括:计算虚假状态转移的判别得分与真实状态转移的判别得分之间的差值,得到真实的状态转移与虚假的状态转移之间的差距;
所述利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为强化学习的内在奖励即为:利用真实的状态转移与虚假的状态转移之间的差距作为强化学习的内在奖励。
5.如权利要求3所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,利用自监督时序距离预测方法学习状态观测的时序连续的特征表示,同时采用对抗学习的方法,通过判别器判别评分指导在特征表示的空间中精准预测单步转移规律。
6.如权利要求1所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,所述强化学习中的状态转移按照如下方法获取:在强化学习中,智能体获取到环境的当前状态,并根据策略基于环境的当前状态选出执行的动作,智能体根据选出的执行的动作与环境交互产生状态转移。
7.如权利要求1所述的基于离线预训练状态转移Transformer模型的强化学习方法,其特征在于,所述强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略包括:强化学习的智能体在环境中交互,在状态转移Transformer模型计算出的内在奖励的激励下通过最大化如下目标J迭代更新策略实现策略提升,最终得到最优策略:
其中,π表示策略,ρ0表示初始状态分布,at表示在当前状态st下根据策略分布π(·|st)执行的动作,(st,st+1)表示当前时刻状态到下一时刻状态的转移,表示状态转移函数,γ为折扣因子,r(st,st+1)表示由状态转移Transformer模型针对(st,st+1)给出的内在奖励,表示期望,J表示最大化目标即最大化折扣奖励和的期望。
8.一种基于离线预训练状态转移Transformer模型的强化学习装置,其特征在于,包括:
状态转移Transformer模型离线预训练模块,用于基于视频的观测数据,对Transformer模型进行离线预训练得到状态转移Transformer模型,以使所述状态转移Transformer模型根据输入的当前状态预测得到下一步状态,并得到从当前状态到下一步状态的状态转移的判别得分;
强化学习策略训练模块,用于利用所述状态转移Transformer模型得到强化学习中的状态转移的判别得分作为内在奖励,以使强化学习的智能体根据所述内在奖励进行策略的学习和迭代,得到最优策略。
9.一种存储器,其特征在于,存储有多条指令,所述指令用于实现如权利要求1-7任一项所述的基于离线预训练状态转移Transformer模型的强化学习方法。
10.一种电子设备,其特征在于,包括处理器和与所述处理器连接的存储器,所述存储器存储有多条指令,所述指令可被所述处理器加载并执行,以使所述处理器能够执行如权利要求1-7任一项所述的基于离线预训练状态转移Transformer模型的强化学习方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310737435.3A CN116702872A (zh) | 2023-06-20 | 2023-06-20 | 基于离线预训练状态转移Transformer模型的强化学习方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310737435.3A CN116702872A (zh) | 2023-06-20 | 2023-06-20 | 基于离线预训练状态转移Transformer模型的强化学习方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116702872A true CN116702872A (zh) | 2023-09-05 |
Family
ID=87825438
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310737435.3A Pending CN116702872A (zh) | 2023-06-20 | 2023-06-20 | 基于离线预训练状态转移Transformer模型的强化学习方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116702872A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117933346A (zh) * | 2024-03-25 | 2024-04-26 | 之江实验室 | 一种基于自监督强化学习的即时奖励学习方法 |
CN117953351A (zh) * | 2024-03-27 | 2024-04-30 | 之江实验室 | 一种基于模型强化学习的决策方法 |
-
2023
- 2023-06-20 CN CN202310737435.3A patent/CN116702872A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117933346A (zh) * | 2024-03-25 | 2024-04-26 | 之江实验室 | 一种基于自监督强化学习的即时奖励学习方法 |
CN117953351A (zh) * | 2024-03-27 | 2024-04-30 | 之江实验室 | 一种基于模型强化学习的决策方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Greydanus et al. | Visualizing and understanding atari agents | |
Mousavi et al. | Deep reinforcement learning: an overview | |
Lei et al. | Dynamic path planning of unknown environment based on deep reinforcement learning | |
CN107403426B (zh) | 一种目标物体检测方法及设备 | |
CN105637540B (zh) | 用于强化学习的方法和设备 | |
US20180268292A1 (en) | Learning efficient object detection models with knowledge distillation | |
CN116702872A (zh) | 基于离线预训练状态转移Transformer模型的强化学习方法和装置 | |
CN111144580B (zh) | 一种基于模仿学习的层级强化学习训练方法和装置 | |
CN110770759B (zh) | 神经网络*** | |
de la Cruz et al. | Pre-training with non-expert human demonstration for deep reinforcement learning | |
CN111507378A (zh) | 训练图像处理模型的方法和装置 | |
US11580378B2 (en) | Reinforcement learning for concurrent actions | |
CN111602144A (zh) | 生成指令序列以控制执行任务的代理的生成神经网络*** | |
Zieliński et al. | 3D robotic navigation using a vision-based deep reinforcement learning model | |
CN111352419B (zh) | 基于时序差分更新经验回放缓存的路径规划方法及*** | |
EP2363251A1 (en) | Robot with Behavioral Sequences on the basis of learned Petri Net Representations | |
CN111902812A (zh) | 电子装置及其控制方法 | |
Bertoin et al. | Local feature swapping for generalization in reinforcement learning | |
CN113407820B (zh) | 利用模型进行数据处理的方法及相关***、存储介质 | |
Chen et al. | Toward a brain-inspired system: Deep recurrent reinforcement learning for a simulated self-driving agent | |
Ji et al. | Improving decision-making efficiency of image game based on deep Q-learning | |
Shao et al. | Visual navigation with actor-critic deep reinforcement learning | |
CN115797517B (zh) | 虚拟模型的数据处理方法、装置、设备和介质 | |
CN112121419A (zh) | 虚拟对象控制方法、装置、电子设备以及存储介质 | |
Saito et al. | Python reinforcement learning projects: eight hands-on projects exploring reinforcement learning algorithms using TensorFlow |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |