CN111476020A - 一种基于元强化学习的文本生成方法 - Google Patents

一种基于元强化学习的文本生成方法 Download PDF

Info

Publication number
CN111476020A
CN111476020A CN202010156433.1A CN202010156433A CN111476020A CN 111476020 A CN111476020 A CN 111476020A CN 202010156433 A CN202010156433 A CN 202010156433A CN 111476020 A CN111476020 A CN 111476020A
Authority
CN
China
Prior art keywords
text
model
text generation
generating
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010156433.1A
Other languages
English (en)
Other versions
CN111476020B (zh
Inventor
赵婷婷
宋亚静
王嫄
任德华
杨巨成
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tianjin University of Science and Technology
Original Assignee
Tianjin University of Science and Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tianjin University of Science and Technology filed Critical Tianjin University of Science and Technology
Priority to CN202010156433.1A priority Critical patent/CN111476020B/zh
Publication of CN111476020A publication Critical patent/CN111476020A/zh
Application granted granted Critical
Publication of CN111476020B publication Critical patent/CN111476020B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Machine Translation (AREA)

Abstract

本发明涉及一种基于元强化学习的文本生成方法,其技术特点是:收集不同类型的文本数据作为不同任务的划分;收集文本数据中随机采取某一任务的数据;采用处理序列数据的递归型神经网络构造文本生成模型;生成K条文本轨迹;利用文本生成轨迹对文本生成模型进行少次策略梯度更新,得到更新后的文本生成模型;生成新的文本轨迹;在多个任务上分别对文本生成模型进行更新并采样,得到文本生成轨迹的表现误差;对原始文本生成模型参数进行二次梯度更新训练至收敛。本发明在强化学习利用递归神经网络进行文本生成的基础上进行改良,利用元强化学习训练智能体,将在多个任务上学习到的经验迁移到目标任务中,可快速实现不同场景或语境下的文本生成。

Description

一种基于元强化学习的文本生成方法
技术领域
本发明属于计算机自然语言处理技术领域,尤其是一种基于元强化学习的文本生成方法。
背景技术
自然语言处理(NLP),特别是自然语言生成(NLG)问题,长期以来一直被认为是最具挑战性的计算任务之一。自然语言生成是让计算机具有与人一样的表达和写作能力的技术,它可以根据一些关键信息及其在机器内部的表达形式,经过规划自动生成一段高质量的自然语言文本。从最开始的模式匹配生成,通过一些简单的句法、语法规则来组织生成文本;到后来基于统计概率模型;再到现在伴随着深度学习的快速发展,基于深度学习的自然语言生成技术有了较为突出的进展,各种神经网络被提议出生成准确、自然和多样化的文本。
强化学习(reinforcement learning,简称RL)作为机器学习中的一个重要研究领域,以试错的机制与环境进行交互,通过最大化累积奖赏来学习最优策略。该技术可以将使用递归神经网络生成文本看作是一个马尔可夫决策过程(MDP),其局部最优策略可以通过强化学习找到,这在最近的研究中取得了很好的结果。然而,现有的文本生成方法通常是针对特定领域开发的。而现实世界中的自然语言往往是多个领域的,且不同领域间的文本在语法、语义等规则上是一致的。此外,神经网络的训练往往需要大量的数据,标注充分学习数据需要花费大量的时间与金钱。因此,样本的收集及场景的适应能力是文本生成应用中的一个重要瓶颈问题。
发明内容
本发明的目的在于克服现有技术的不足,提出一种基于元强化学习的文本生成方法,用于解决真实世界中语言生成模型快速适应不同场景进行文本生成以及个别场景下学习样本不好收集的瓶颈问题。
本发明解决其技术问题是采取以下技术方案实现的:
一种基于元强化学习的文本生成方法,包括以下步骤:
步骤1、收集不同类型的文本数据作为不同任务的划分;
步骤2、从步骤1收集的文本数据中随机采取某一任务τi的数据;
步骤3、采用处理序列数据的递归型神经网络构造文本生成模型fθ
步骤4、利用文本生成模型fθ生成K条文本轨迹Di
步骤5、利用文本生成轨迹Di对文本生成模型fθ进行少次策略梯度更新,得到更新后的文本生成模型fθ';
步骤6、利用文本生成模型fθ'生成新的文本轨迹Di';
步骤7、重复步骤2至步骤6,在多个任务上分别对文本生成模型进行更新并采样,得到文本生成轨迹的表现误差;
步骤8、利用步骤7得到文本生成轨迹的表现误差对原始文本生成模型参数进行二次梯度更新训练至收敛。
所述步骤1中收集不同类型的文本数据为自然语言的不同场景。
所述步骤3中递归型神经网络为强化学习中的智能体,其输出一个概率密度函数p(yt|Y1:t-1,),其中,Y1:t-1为文本生成模型在t时刻的状态st,表示已生成的字符序列串,yt为文本生成模型在t时刻的动作at,表示当前选择的字符。
所述步骤4采用REINFORCE方法对参数进行少次梯度更新,将奖励函数设定为真实文本数据与生成文本数据的双语评估替补分数。
所述步骤8采用文本生成模型fθ'的采样数据对原始生成模型fθ进行二次梯度更新。
本发明的优点和积极效果是:
1、本发明设计合理,其通过递归神经网络分析输入信息之间的整体关联,处理文本序列以进行生成,然后在多个场景上利用更新后模型的采样轨迹表现对原始模型进行更新,同时,利用元学习训练模型参数,使得模型参数只要经过少量次数的梯度更新就能实现在新场景文本生成任务上的快速学习。这种学会学习的智能体是通往可持续学习多项新任务的多面智能体的必经之路。因此,本发明不仅可以使智能体具备快速学习、快速适应新环境的能力,而且在给定样本数量较少或采集样本的预算有限的情形下,具有快速、准确的特点。
2、本发明使得智能体在有少量文本生成学习样本的情况下,也可以利用文本生成模型并通过少量次数的梯度更新适应新场景,摆脱了文本生成应用对大量学习样本的要求,在一定程度上解决了语言生成模型在某些场景下数据不足的瓶颈问题。
3、本发明在强化学习利用递归神经网络进行文本生成的基础上进行改良,利用元强化学习训练智能体,将在多个任务上学习到的经验迁移到目标任务中,在一定程度上解决了语言生成模型在实际应用中需要大量学习样本以及难以快速适应不同场景的问题,从而可以快速实现不同场景或语境下的文本生成。
附图说明
图1是本发明的元强化学习文本生成图。
具体实施方式
以下结合附图对本发明做进一步详述。
元强化学习(Meta Reinforcement Learning,简称Meta RL)是将元学习应用到强化学习的一个研究方向,其核心的想法就是希望智能体在学习大量的强化学习任务中获取足够的先验知识,然后在面对新的强化学习任务时能够学的更快,学的更好,能够快速自适应新的学习环境。
如图1所示,首先从数据库Dtrain中选取任一任务τi的文本数据作为当前的生成环境:初始化文本生成模型M,将其生成文本与真实文本比较作为训练误差Lossn,使用策略梯度方法对生成模型M进行少次内部梯度更新为M'n。然后使用更新后的模型M'n继续采样文本轨迹,计算其表现误差Lossn',重复以上步骤,对n个不同任务计算误差Lossn'。最后,对n个误差求和,通过多个更新后模型的表现对原始生成模型进行外部梯度更新。
在强化学习进行文本生成时,采用的是递归神经网络。递归神经网络带有一个指向自身的环,用来表示它可以传递当前时刻处理的信息给下一时刻使用。递归神经网络的输入是一整个序列,也就是x=[x1,…xt-1,xt,xt+1,…xT],xt是网络某一时刻的输入。网络t时刻的隐藏状态ht是关于前一时刻的隐藏状态ht-1和当前时刻的输入xt的函数,即ht结合了历史信息及当前的输入信息。网络的输出是关于ht的函数,在结合了历史信息和当前输入的情况下,递归神经网络能够很好地处理序列问题,能够预测下一时刻状态的输出和自身的隐状态。强化学习的当前状态是在t时刻已生成的字符串st=Y1:t-1=(y1,…,yt-1),当前动作为t时刻选择的字符yt,在确定了某一字符yt后,状态由st=Y1:t-1=(y1,…,yt-1)确定性转移到st'=Y1:t=(y1,…,yt)。
本发明在上述数学模型及目标函数的基础上,通过运用元强化学习的快速适应性能,在不同场景下训练生成模型,对生成模型适应场景的能力进行提升。通过学习不同场景下文本的生成能力,来解决文本生成应用对大量学习样本的要求,从而应对真实世界中语言生成模型快速适应不同场景进行文本生成以及个别场景下学习样本不好收集的瓶颈问题。
本发明的设计思路为:整体生成模型分为元学习外部梯度更新及采用递归神经网络进行强化学习文本生成的内部更新两个部分。其中采用递归型神经网络作为强化学习智能体,结合历史数据,进行少次内部策略梯度更新,不断训练为生成更符合人类自然语言的文本数据。采用元学习的方法通过更新后模型的表现对原始生成模型进行二次梯度更新,以增强智能体对环境的适应能力,从而最终得到具有快速学习能力的文本生成模型。本发明采用元强化学习技术,可以快速适应新任务且仅需少次训练数据的能力,从而应对真实世界中语言生成模型快速适应不同场景进行文本生成以及个别场景下学习样本不好收集的瓶颈问题。
基于上述设计思路,本发明首先收集不同类型的文本数据;其次在某一类型的数据上训练强化学习智能体,进行少次内部策略梯度更新,使其具备生成更符合人类自然语言文本数据的能力;通过在多个类型数据上的重复训练,最后通过更新后模型的表现对原始生成模型进行二次梯度更新,使其具有快速学习能力。具体方法包括以下步骤:
步骤1、收集不同类型的文本数据作为不同任务的划分。
本发明利用元强化学习能够快速适应新任务且仅需少次训练数据的能力,和递归神经网络在处理序列问题方面的优势,从而应对真实世界中语言生成模型快速适应不同场景进行文本生成以及个别场景下学习样本不好收集的瓶颈问题。在本步骤中,需要收集不同类型的文本数据作为元学习不同任务,以学习先验知识能够快速适应新的场景,此外,该数据作为强化学习中奖励函数设定的依赖数据,帮助生成模型进行训练。
本步骤收集不同类型的文本数据可以是自然语言的不同场景,例如:天气、科技、餐馆、篮球等场景。
步骤2、从步骤1收集的文本数据中随机采取某一任务τi的数据。
步骤3、采用处理序列数据的递归型神经网络构造文本生成模型fθ,即递归神经网络模型。
在本步骤中,使用递归神经网络不仅能够识别个体输入,更能分析输入信息之间的整体关联,是一种具有记忆力功能的神经网络。
递归神经网络(RNN)被看作是强化学习中的智能体,其输出一个概率密度函数p(yt|Y1:t-1,),而不是一个确定性预测yt。这里的Y1:t-1表示文本生成模型在t时刻的状态st,即已生成的字符序列串,yt表示文本生成模型在t时刻的动作at即当前选择的字符。
步骤4、利用步骤3的文本生成模型生成K条文本轨迹Di
步骤5、利用步骤4的文本生成轨迹Di对文本生成模型进行少次策略梯度更新。
在本步骤中,策略梯度更新是采用REINFORCE方法对参数进行少次梯度更新,其中,奖励函数设定为真实文本数据与生成文本数据的双语评估替补分数,即BLEU(Bilingual Evaluation Understudy)分数。
步骤6、利用步骤5得到的文本生成模型fθ'生成文本轨迹Di'。
步骤7、重复步骤2至步骤6,在多个任务上分别对文本生成模型进行更新并采样,得到文本生成轨迹的表现误差。
步骤8、利用步骤7得到文本生成轨迹的表现误差对原始文本生成模型参数进行二次梯度更新训练至收敛。
在本步骤中,二次梯度更新是利用文本生成模型fθ'的采样数据对原始生成模型fθ进行二次梯度更新。
需要强调的是,本发明所述的实施例是说明性的,而不是限定性的,因此本发明包括并不限于具体实施方式中所述的实施例,凡是由本领域技术人员根据本发明的技术方案得出的其他实施方式,同样属于本发明保护的范围。

Claims (5)

1.一种基于元强化学习的文本生成方法,其特征在于包括以下步骤:
步骤1、收集不同类型的文本数据作为不同任务的划分;
步骤2、从步骤1收集的文本数据中随机采取某一任务τi的数据;
步骤3、采用处理序列数据的递归型神经网络构造文本生成模型fθ
步骤4、利用文本生成模型fθ生成K条文本轨迹Di
步骤5、利用文本生成轨迹Di对文本生成模型fθ进行少次策略梯度更新,得到更新后的文本生成模型f′θ
步骤6、利用文本生成模型f′θ生成新的文本轨迹D′i
步骤7、重复步骤2至步骤6,在多个任务上分别对文本生成模型进行更新并采样,得到文本生成轨迹的表现误差;
步骤8、利用步骤7得到文本生成轨迹的表现误差对原始文本生成模型参数进行二次梯度更新训练至收敛。
2.根据权利要求1所述的一种基于元强化学习的文本生成方法,其特征在于:所述步骤1中收集不同类型的文本数据为自然语言的不同场景。
3.根据权利要求1所述的一种基于元强化学习的文本生成方法,其特征在于:所述步骤3中递归型神经网络为强化学习中的智能体,其输出一个概率密度函数p(yt|Y1:t-1,),其中,Y1:t-1为文本生成模型在t时刻的状态st,表示已生成的字符序列串,yt为文本生成模型在t时刻的动作at,表示当前选择的字符。
4.根据权利要求1所述的一种基于元强化学习的文本生成方法,其特征在于:所述步骤4采用REINFORCE方法对参数进行少次梯度更新,将奖励函数设定为真实文本数据与生成文本数据的双语评估替补分数。
5.根据权利要求1所述的一种基于元强化学习的文本生成方法,其特征在于:所述步骤8采用文本生成模型f′θ的采样数据对原始生成模型fθ进行二次梯度更新。
CN202010156433.1A 2020-03-09 2020-03-09 一种基于元强化学习的文本生成方法 Active CN111476020B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010156433.1A CN111476020B (zh) 2020-03-09 2020-03-09 一种基于元强化学习的文本生成方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010156433.1A CN111476020B (zh) 2020-03-09 2020-03-09 一种基于元强化学习的文本生成方法

Publications (2)

Publication Number Publication Date
CN111476020A true CN111476020A (zh) 2020-07-31
CN111476020B CN111476020B (zh) 2023-07-25

Family

ID=71748074

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010156433.1A Active CN111476020B (zh) 2020-03-09 2020-03-09 一种基于元强化学习的文本生成方法

Country Status (1)

Country Link
CN (1) CN111476020B (zh)

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150100530A1 (en) * 2013-10-08 2015-04-09 Google Inc. Methods and apparatus for reinforcement learning
US20190354859A1 (en) * 2018-05-18 2019-11-21 Deepmind Technologies Limited Meta-gradient updates for training return functions for reinforcement learning systems

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150100530A1 (en) * 2013-10-08 2015-04-09 Google Inc. Methods and apparatus for reinforcement learning
US20190354859A1 (en) * 2018-05-18 2019-11-21 Deepmind Technologies Limited Meta-gradient updates for training return functions for reinforcement learning systems

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
冯少迪;: "基于强化学习的自然语言处理技术", 数码世界, no. 03 *
唐振韬;邵坤;赵冬斌;朱圆恒;: "深度强化学习进展:从AlphaGo到AlphaGo Zero", 控制理论与应用, no. 12 *
赵星宇;丁世飞;: "深度强化学习研究综述", 计算机科学, no. 07 *

Also Published As

Publication number Publication date
CN111476020B (zh) 2023-07-25

Similar Documents

Publication Publication Date Title
Tian et al. Off-policy reinforcement learning for efficient and effective gan architecture search
CN108724182B (zh) 基于多类别模仿学习的端到端游戏机器人生成方法及***
CN110991027A (zh) 一种基于虚拟场景训练的机器人模仿学习方法
CN110766044B (zh) 一种基于高斯过程先验指导的神经网络训练方法
CN109840595B (zh) 一种基于群体学习行为特征的知识追踪方法
CN112172813B (zh) 基于深度逆强化学习的模拟驾驶风格的跟车***及方法
CN106850289B (zh) 结合高斯过程与强化学习的服务组合方法
CN113312925B (zh) 一种基于自强化学习的遥感影像文本生成及优化方法
El Gourari et al. The Implementation of Deep Reinforcement Learning in E‐Learning and Distance Learning: Remote Practical Work
KR20240034804A (ko) 자동 회귀 언어 모델 신경망을 사용하여 출력 시퀀스 평가
CN113313265A (zh) 基于带噪声专家示范的强化学习方法
Liu et al. Smart city moving target tracking algorithm based on quantum genetic and particle filter
CN115860107A (zh) 一种基于多智能体深度强化学习的多机探寻方法及***
CN115269861A (zh) 基于生成式对抗模仿学习的强化学习知识图谱推理方法
CN114911969A (zh) 一种基于用户行为模型的推荐策略优化方法和***
He et al. Influence-augmented online planning for complex environments
CN111476020A (zh) 一种基于元强化学习的文本生成方法
CN112297012B (zh) 一种基于自适应模型的机器人强化学习方法
CN115453880A (zh) 基于对抗神经网络的用于状态预测的生成模型的训练方法
CN115212549A (zh) 一种对抗场景下的对手模型构建方法及存储介质
CN114154582A (zh) 基于环境动态分解模型的深度强化学习方法
Li et al. Policy gradient methods with gaussian process modelling acceleration
CN113821452A (zh) 根据被测***测试表现动态生成测试案例的智能测试方法
CN115668215A (zh) 用于训练参数化策略的装置和方法
CN113139644A (zh) 一种基于深度蒙特卡洛树搜索的信源导航方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant