JP6704583B2 - 学習システムおよび学習方法 - Google Patents

学習システムおよび学習方法 Download PDF

Info

Publication number
JP6704583B2
JP6704583B2 JP2016253169A JP2016253169A JP6704583B2 JP 6704583 B2 JP6704583 B2 JP 6704583B2 JP 2016253169 A JP2016253169 A JP 2016253169A JP 2016253169 A JP2016253169 A JP 2016253169A JP 6704583 B2 JP6704583 B2 JP 6704583B2
Authority
JP
Japan
Prior art keywords
parameter
differential value
learning
updating
value
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
JP2016253169A
Other languages
English (en)
Other versions
JP2018106489A (ja
Inventor
育郎 佐藤
育郎 佐藤
亮 藤崎
亮 藤崎
哲弘 野村
哲弘 野村
洋介 大山
洋介 大山
松岡 聡
聡 松岡
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tokyo Institute of Technology NUC
Denso IT Laboratory Inc
Original Assignee
Tokyo Institute of Technology NUC
Denso IT Laboratory Inc
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tokyo Institute of Technology NUC, Denso IT Laboratory Inc filed Critical Tokyo Institute of Technology NUC
Priority to JP2016253169A priority Critical patent/JP6704583B2/ja
Priority to US15/795,691 priority patent/US11521057B2/en
Priority to CN201711425028.XA priority patent/CN108241889A/zh
Publication of JP2018106489A publication Critical patent/JP2018106489A/ja
Application granted granted Critical
Publication of JP6704583B2 publication Critical patent/JP6704583B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F17/00Digital computing or data processing equipment or methods, specially adapted for specific functions
    • G06F17/10Complex mathematical operations
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Mathematical Physics (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Artificial Intelligence (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Algebra (AREA)
  • Computational Mathematics (AREA)
  • Mathematical Analysis (AREA)
  • Mathematical Optimization (AREA)
  • Pure & Applied Mathematics (AREA)
  • Databases & Information Systems (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Image Analysis (AREA)

Description

本発明は、ニューラルネットワーク用のパラメタを更新する学習システムおよび学習方法に関する。
画像認識の分野において、一般物体認識と呼ばれる問題がある。これは、画像の中に存在する鳥や車といった物体の種別(クラス)を推定する問題である。近年、一般物体認識問題の認識性能の改善が目覚ましい。これは、とりわけ層数の多い畳み込みニューラルネットワーク(Convolution Neural Network、以下CNNという。例えば、非特許文献1参照)によるところが大きい。
画像認識の分野では過去様々な認識アルゴリズムが提案されてきたが、学習データ(入力データと正解との組)が膨大になるにつれ、CNNが他のアルゴリズムの認識性能を上回る傾向となっている。CNNは、モデルの表現能力が高い反面、学習データの特徴に過度に特化してしまう「過学習」と呼ばれる問題があることが従来指摘されてきた。しかしながら、近年の学習データ量が過学習の問題の回避を可能にするレベルにまで増大しつつある。
Ren Wu, Shengen Yan, Yi Shan, Qingqing Dang, and Gang Sun, "Deep Image: Scaling up Image Recognition", arXiv:1501.02876v2. C. M. Bishop, "Neural Networks for Pattern Recognition", p267-268, Clarendon Press (1996). Y. Nesterov, "A method for unconstrained convex minimization problem with the rate of convergence o(1/k2)". Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543-547 (1983). I. Sutskever, "Training Recurrent neural Networks", PhD Thesis, (2013). J. Dean, et al., "Large Scale Distributed Deep Networks", NIPS 2012.
CNNは認識性能において大きな利点があるが、学習時間が長大であるという弱点を併せ持っている。ソーシャルネットワークに関するデータや、自動運転に関係するデータなどは今後とも増加の一途をたどるものの一例であるが、いつか学習時間が膨大になりすぎて、実質的な時間内に学習が終了しない可能性も充分に考えられる。場合によっては、学習に年単位の時間を要することも考えられる。こうなった場合、製品化は現実的ではなく、認識性能で劣るCNN以外のアルゴリズムの使用を余儀なくされることにもなり兼ねない。すなわち、ニューラルネットワーク学習の抜本的高速化は、産業応用にとって極めて重要な課題である。
本発明はこのような問題点に鑑みてなされたものであり、本発明の課題は、より高速にニューラルネットワーク用のパラメタを更新できる学習システムおよび学習方法を提供することである。
本発明の一態様によれば、複数の微分値算出部と、パラメタ更新部とを備え、ニューラルネットワーク用のパラメタを更新する学習システムであって、前記複数の微分値算出部のそれぞれは、互いに同期することなく、前記パラメタ更新部からある時点でのパラメタを受信し、前記受信したパラメタに基づいて、前記パラメタの更新に用いられる微分値を算出し、前記微分値を前記パラメタ更新部に送信し、前記パラメタ更新部は、前記微分値算出部から前記微分値を受信し、前記複数の微分値算出部による微分値算出と同期することなく、受信した前記微分値に基づいて前記パラメタを更新し、更新後の前記パラメタを前記複数の微分値算出部に送信し、前記微分値算出部は、前記微分値を算出する際、前記パラメタの受信時点から、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられるまでの間に前記パラメタが更新される回数に対応するステイルネスを考慮して、前記微分値を算出する、学習システムが提供される。
非同期型の分散方式において、ステイルネスを考慮するため、高速にパラメタを更新できる。
前記微分値算出部は、前記ステイルネスを考慮して、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられる際のパラメタの予測値を算出し、該予測値を微分して前記微分値を算出するのが望ましい。
この場合、前記微分値算出部は、前記ある時点でのパラメタと、前記ある時点より過去の時点でのパラメタと、前記ステイルネスと、に基づいて、前記予測値を算出するのが望ましい。
具体的には、前記ある時点τでのパラメタをW(τ)とし、前記ある時点より過去の時点(τ−1)でのパラメタをW(τ-1)とし、前記ステイルネスをsnとし、γ∈(0,1)とするとき、前記予測値P(τ)は、

であってもよい。
受信したパラメタそのものではなく、更新に用いられる際のパラメタに近い予測値を微分するため、高速にパラメタを更新できる。
前記微分値算出部は、学習が進むにつれて前記γを大きな値としてもよい。
あるいは、前記パラメタ更新部は、前記微分値に学習係数を乗じた値を用いて前記パラメタを更新し、学習が進むにつれて前記学習係数を大きな値としてもよい。
また、前記パラメタ更新部は、学習の開始から所定回パラメタ更新が行われるまでは、前記パラメタの更新量の絶対値またはノルムが所定値を超えないよう、更新量を調整してもよい。
このようにすることで、学習初期に目的関数が不安定となるのを抑制できる。
また、本発明の別の態様によれば、ニューラルネットワーク用のパラメタを更新する学習方法であって、複数の微分値算出部のそれぞれが、互いに同期することなく、パラメタ更新部からある時点でのパラメタを受信するステップと、前記受信したパラメタに基づいて、前記パラメタの更新に用いられる微分値を算出するステップと、前記微分値を前記パラメタ更新部に送信するステップと、前記パラメタ更新部が、前記微分値算出部から前記微分値を受信するステップと、前記複数の微分値算出部による微分値算出と同期することなく、受信した前記微分値に基づいて前記パラメタを更新するステップと、更新後の前記パラメタを前記複数の微分値算出部に送信するステップと、を備え、前記微分値算出部が前記微分値を算出するステップでは、前記パラメタの受信時点から、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられるまでの間に前記パラメタが更新される回数に対応するステイルネスを考慮して、前記微分値を算出する、学習方法が提供される。
ステイルネスを考慮するため、高速にパラメタを更新できる。
CNN構造の一例を模式的に示す図。 NAG法によるパラメタ更新を説明する図。 同期型における微分値算出およびパラメタ更新のタイミングを模式的に説明する図。 非同期型における微分値算出およびパラメタ更新のタイミングを模式的に説明する図。 一実施形態に係る分散方式の学習システムの一構成例を示す概略ブロック図。 本実施形態におけるパラメタ更新を説明する図。 本発明による誤差d1と、従来法による誤差d2とを比較したグラフ。 学習を繰り返し行った際のエラー率の遷移を示すグラフ。
以下、本発明に係る実施形態について、図面を参照しながら具体的に説明する。
図1は、CNN構造の一例を模式的に示す図である。CNNは、1または複数の畳み込み層121およびプーリング層122の組と、多層ニューラルネットワーク構造123とを備えている。初段の畳み込み層121に認識対象(以下、画像データとする)が入力される。そして、多層ニューラルネットワーク構造123から認識結果が出力される。
畳み込み層121は、入力される画像データ(初段の畳み込み層121にあっては認識対象の画像データ、2段目以降の畳み込み層121にあっては後述する特徴マップ)に対してフィルタ21aを適用して畳み込みを行い、次いで非線形写像を行う。フィルタ21aは複数ピクセルの要素を持つ重みであり、各重みはバイアスを含んでいてもよい。
プーリング層122は、畳み込み層121からの画像データの解像度を下げるプーリング操作を行い、特徴マップを生成する。
多層ニューラルネットワーク構造123は、入力層231と、1または複数の隠れ層232と、出力層233とを有する。入力層231には最終段のプーリング層122からの特徴マップが入力される。隠れ層232は重みを用いて積和演算を行う。出力層233はCNN処理の最終結果を出力する。
畳み込み層121におけるフィルタ21aの重みや、隠れ層232における重みがニューラルネットワーク用のパラメタであり、事前に学習しておく必要がある。ここでの学習とは、認識対象の画像データが入力されたときにCNNが理想の出力を返すようパラメタを更新して最適化することであり、具体的には、所定の目的関数が最小値へ収束するまでパラメタの更新を反復的に行う。目的関数とはCNNがどの程度理想的な出力値から離れているかを定量化した関数(例えば二乗誤差やクロスエントロピー)を全学習データ分だけ足し合わせた関数である。目的関数はパラメタの関数であり、目的関数が小さいほどよいCNNであるといえる。
本実施形態では、目的関数を最小化する手法として、次に説明するミニバッチ確率的勾配法と呼ばれる勾配法の一種を踏襲する。学習には、認識の対象となる画像データと、それに対する理想的な出力値の組である学習データが多数用いられる。ミニバッチ確率的勾配法では、1回のパラメタ更新に全学習データを用いるのではなく、一部の学習データを用いる。ミニバッチとは1回のパラメタ更新に用いられる学習データの集合を指し、ミニバッチサイズとはミニバッチを構成する学習データの数を指す。
ミニバッチ確率的勾配法では、まず、全学習データからランダムに学習データを取り出してミニバッチを作成する。次いで、そのミニバッチを使って目的関数のパラメタに対する微分値を算出する。そして、微分値を用いてパラメタを更新する。以下、より詳しく説明する。
目的関数をJ(x;W)とする。ここで、xは入力データであり、Wはパラメタの集合である。t回目の反復におけるパラメタをW(t)とすると、ミニバッチ確率的勾配法において、t+1回目のパラメタを得る更新式は次式によって与えられる。
ここで、xr(i,t)はデータセットの中のr(i,t)番目のデータを指し、r(i,t)は時刻(反復回数の順番を示す番号のことを時刻とも表現する)tのパラメタ使用時においてi番目にランダムサンプリングしたデータインデックスを指す。ηは学習係数と呼ばれる正の数である。mはミニバッチサイズである。右辺第2項が微分値であり、そのまわりの括弧([])の右下には、パラメタ空間におけるポイントが示されており、このポイントにおいて微分が計算されることを意味している。更新式(1)においては、時刻tにおけるパラメタW(t)について括弧([])内に示す微分を行うことを示している。
更新式(1)はミニバッチを用いた最急降下法として知られている。つまり、右辺第2項のベクトルはミニバッチによって規定される目的関数Jを最も急峻に下る方向を持っている。
このようなミニバッチ確率的勾配法を用いて大規模なデータセットを学習することで、多層ニューラルネットワークは高い汎化性能(未知データに対する推定の精度)を獲得できる。
ミニバッチ確率的勾配法において、学習の収束速度を高める手法が既に検討されており、よく知られている2つの従来技術について説明する。
[モメンタム法](非特許文献2参照)
パラメタ更新式は次式で与えられる。
Vはモメンタム項と呼ばれる項であり、過去に算出された目的関数の微分の加重平均である。荷重の度合いはモメンタム係数γ∈(0,1)によって調整され、通常0.9などの値が設定される。この更新式において、微分値の時系列相関がある場合、更新式において、モメンタム項Vが主要項であり、微分値が補正項であるとみなすことができる。
モメンタム法によれば、微分値が時刻tの近傍において高い相関を持つ場合に、一回あたりのパラメタ更新のステップ幅を大きく取れ、収束までの時間を短縮できる。
[ネステロフの加速勾配法(Nesterov Accelerated Gradient、以下NAG法という)](非特許文献3,4参照)
パラメタ更新式は次式で与えられる。
モメンタム法との違いは、パラメタ空間において微分値を算出するポイントである。モメンタム法では時刻tにおけるパラメタW(t)を用いて微分値を算出するのに対し、NAG法ではW(t)+γV(t)を用いて微分値を算出する(微分値まわりの括弧([])の右下参照)。すなわち、NAG法では、まず過去のモメンタム項Vを主要項として加算し、その加算結果に対して補正項である微分値を算出して加算する。
NAG法について、図2を用いてより詳しく説明する。同図は、時刻(t−1)におけるパラメタW(t-1)(符号11)を更新して時刻tにおけるパラメタW(t)(符号12)が得られた後に、さらに時刻(t+1)におけるパラメタW(t+1)を得ることを説明している。上記(3)式より、パラメタW(t+1)は次式で表される。なお、目的関数Jの引数は省略している。
ここで、P(t)=W(t)+γV(t)とした。図2において、パラメタW(t-1)からパラメタW(t)に向かうベクトル(符号13)がモメンタム項V(t)である。パラメタW(t)に主要項である重み付きのモメンタム項γV(t)(符号14)を加算したポイントがP(t)(符号15)である。
そして、主要項による変更を受けたP(t)において、補正項である微分値

を算出する。P(t)に重み付きの微分値(符号16)


を加算することで、時刻(t+1)におけるパラメタW(t+1)(符号17)が得られる。
多くの場合において、モメンタム法およびNAG法とも確率的勾配法を加速できることが経験的に知られている。以上のようなモメンタム法あるいはNAG法によって、ある程度学習の高速化が可能である。しかしながら、特にデータの数が非常に多い場合においては、まだ十分な高速化が実現できているとは言い難い。
そこで、本発明では、CPUやGPUから構成される複数の計算機(ノードともいう)を高速な通信回線で接続した計算機クラスタを利用し、学習に必要な計算を分散処理させる方式(以下、分散方式という)を採用することで、学習時間をさらに短縮することを図る。
(第1の実施形態)
分散方式では、(1)「何を」通信するのか、(2)「何と」通信するのか、(3)「いつ」通信するのか、という観点での分類が可能である。
まず、「『何を』通信するのか」という観点では、「モデル並列」と「データ並列」の方式がある。モデル並列方式では、モデル自体が計算機間で分散化されており、ニューラルネットワークの中間変数が通信される。データ並列方式では、モデルの計算は単体の計算機に閉じており、各計算機で計算された微分値などが通信される。
データ並列方式は各計算機に異なるデータを処理させるため、多数のデータを一括に処理できる。ミニバッチ確率的勾配法を前提とする場合、ミニバッチ内でのデータ並列が理にかなうことから、本明細書ではデータ並列を主に想定している。
モデル並列方式はメモリに収まらないような巨大なニューラルネットワークを扱うときに有用である。よってモデル並列方式は用途が特殊であるため、本明細書では説明しないが、本発明は後述する「古いパラメタ」の使用において効果を有するものであり、モデル並列方式にもデータ並列方式にも適用可能である。
次に、「『誰と』通信するのか」という観点では、パラメタを管理する「パラメタサーバ」と微分計算を行う「ワーカノード」との間で一対一通信を使う方式(例えば非特許文献5)と、パラメタサーバを設けず全ワーカノード間で全体通信(全対全の通信)を使う方式とがある。前者では、ワーカノードはパラメタサーバとのみ通信し、ワーカ−ノード間の通信は原則行わない。本発明はどちらの方式でも適用可能である。
そして、「『いつ』通信するのか」という観点では、同期型と非同期型の方式がある。
図3は、同期型における微分値算出およびパラメタ更新のタイミングを模式的に説明する図である。同期型では、ノードどうしが同期して微分計算を行う。具体的には、最も計算が遅いノードが微分値を算出するのを待ち、全ノードからの微分値を用いてパラメタを更新し、更新された新しいパラメタを用いて全ノードが次の微分計算を一斉に開始する、というのが基本的な流れである。このように同期型では、ノードに待ちが発生するという非効率さがあり、更新頻度は低い。
一方、同期型では、微分値算出とパラメタ更新の順列性が守られるため、通常のミニバッチ確率法における更新式は次式で表される。
kは計算ノードが一括に処理するデータの数である。r(i,n,t)はノードnが時刻tのパラメタ使用時においてi番目にランダムサンプリングしたデータインデックスを指す。#nodesは微分値の算出を行うノードの台数である。単一ノード使用時における更新式(1)との違いは、微分をノード台数分足し合わせている点にすぎない。この時、ミニバッチサイズは#nodes×kで与えられる。
よって、同期型では、ミニバッチを用いた最急降下法である更新式(1)が厳密に実行され、目的関数J(x;W)は単調に減少する。また、モメンタム法あるいはNAG法を適用した場合においても上記更新式(2)または(3)が厳密に実行される。したがって、いずれにしても同期型では収束保障があり、1回の更新における目的関数J(x;W)の降下速度は大きい。
図4は、非同期型における微分値算出およびパラメタ更新のタイミングを模式的に説明する図である。非同期型では、ノードどうしが非同期で微分値算出およびパラメタ更新を行う。非同期型では、全ノードが微分値の算出を終了するのを待つことなくパラメタ更新を行うし、各ノードが互いに同期することなく、間断なく微分値の算出を反復する。このため、ノードに待ちが発生せず高効率に稼働させることでき、更新頻度は高い。
一方で、非同期型では、あるノードが微分値を算出している途中でパラメタ更新が行われることを許容するため、各ノードが保持しているパラメタは徐々に古くなっていく。すなわち、微分値を算出する際に古いパラメタを用いることが、同期型とは大きく異なる。このことを考慮し、ミニバッチ確率法における更新式は例えば次式で表される。
同期型における更新式(4)との違いは、微分値の算出を行うパラメタの時刻τが現在の時刻tよりも古いことである。τはノード毎に異なっていてもよい。なお、本明細書では、時刻に遅れのあるパラメタが微分値算出に使用されているものは、たとえ大多数の計算機が同期型の更新則を使用していたとしても、全て非同期型に分類する。
パラメタの古さのため、更新式(5)における第2項のベクトルは最急降下の方向を持つことを保障できない。したがって、非同期型における更新式(5)は厳密な最急降下法ではなく、1回の更新あたりの目的関数J(x;W)の降下速度は同期型に比べると小さい。ただし、理論的収束保障があるわけではないが、実質的に問題はないとも報告されている(非特許文献5)。
以上のように、通常の同期型の場合は、1回の更新における目的関数J(x;W)の降下速度が大きいというメリットがある一方、更新頻度が低いというデメリットがある。これに対し、通常の非同期型の場合は、更新頻度が高いというメリットがある一方、1回の更新における目的関数J(x;W)の降下速度が小さいというデメリットがある。
したがって、非同期型において、1回の更新における目的関数J(x;W)の降下速度を大きくできれば、非同期型の学習を高速化できる。そこで、本発明では、非同期型の分散方式を前提とし、次のように従来のNAG法をステイルネスのある学習に拡張する。
図5は、一実施形態に係る分散方式の学習システムの一構成例を示す概略ブロック図である。学習システムは互いに通信可能な複数のノード1から構成され、そのそれぞれが微分値算出部2と、パラメタ更新部3とを有する。微分値算出部2は例えばGPU(Graphics Processing Unit)であり、ある時刻におけるパラメタの微分値を算出する。パラメタ更新部3は例えばCPU(Central Processing Unit)であり、算出された微分値を用いて勾配法によりパラメタの更新を行う。
非同期型の分散方式においては、学習システムが複数の微分値算出部2を備え、そのそれぞれが互いに同期することなく、パラメタ更新部3からパラメタを受信して微分値を算出し、パラメタ更新部3に送信する。また、パラメタ更新部3は微分値を受信し、微分値算出部2による微分値算出とは同期することなく、言い換えると、全微分値算出部2からの微分値受信を待つことなく、受信した微分値を用いてパラメタ更新を行い、各微分値算出部2に送信する。
非同期型の分散学習方法では、微分値算出において古いパラメタの使用を余儀なくされる。古いパラメタで計算した微分値を使用して新しいパラメタを更新するが、その微分値は本来使用すべき微分値(新しいパラメタで計算された微分)とは厳密には一致しないことが、1回の更新における目的関数Jの降下速度を小さくしている。微分を計算する際のパラメタが古ければ古いほど、より微分の近似の精度が悪くなり、1回の更新あたりの目的関数Jの降下速度が減少すると考えられる。
つまり、微分値を算出するのに使ったパラメタW(τ)と、その微分値を使って更新を実施する際のパラメタW(t)との距離(例えば||W(τ)−W(t)||2)が、微分の近似精度を特徴づける1つの指標と考えられる。本発明では、この距離を小さくすることで、微分の近似精度を高める。
微分値算出部2がパラメタを受信した際に、微分値を受け渡す未来の時点でのパラメタを予測し、予測したポイントで微分値算出を行うことを考える。ただし、予測そのものの計算に多大な時間が掛かっていては、学習が高速化されないため、できる限り少ない演算量で予測を行う必要がある。
以下、非同期型の確率的勾配法において、時刻tからt+1へのパラメタ更新を考える。なお、微分値を提供する微分値算出部2が保持しているパラメタの時刻番号τは当然時刻t以下である。以下、完全非同期(微分値算出とパラメタ更新の同期が一切行われない方式)の場合を考える。パラメタの古さをステイルネス(staleness)sと呼ぶ。ステイルネスsは、微分値算出部2がパラメタを受信した時点から、微分値算出部2が算出した微分値がパラメタ更新部3によってパラメタの更新に用いられるまでの間にパラメタが更新される回数である。例えば、微分値算出周期とパラメタ更新周期とが一致する非同期システムにおいては、常に2期の遅れがあるためs=2である。
本発明では、更新式を次のようにする。
nはn番目の微分値算出部2におけるステイルネスである。ステイルネスsnは微分値算出部2ごとにパラメタの古さを観測した上で決定してもよい。あるいは、全微分値算出部2においてステイルネスがほぼ共通する場合、snは平均ステイルネスの切り下げ、切り上げまたは四捨五入によって固定値としてもよい。
この更新式(6)は、非同期型の分散システムにNAG法を適用したものである。ただし、単に適用したというだけではなく、ステイルネスsnを考慮した点が本発明の大きな特徴である。更新式(6)では、重み更新の主要項も用いてステイルネスsn分未来のパラメタを予測している。より具体的には、現時刻のモメンタムに現在から順番にγ∈(0,1)のべき乗の重みを付けたものを足し合わせた

を、時刻τのパラメタW(τ)に加算したものが、P(τ)である。P(τ)を算出することは、時刻τにおいて、時刻tから更に更新の主要項であるモメンタム項による変更を受けたパラメタを予測することに対応しており、この意味において本明細書ではP(τ)を予測値とも称する。すなわち、本発明における予測値P(τ)はステイルネスsnに依存し、下式で表される。
図6を用いてより詳しく説明する。同図では、微分値算出周期がパラメタ更新周期と等しく、したがって、ステイルネスsn=2(つまり、t−τ=2)である例を示している。
図6(a)はパラメタ更新部3の処理を説明している。例えば、パラメタ更新部3が時刻(τ−2)〜(τ−1)にかけてパラメタ更新を行い、時刻(τ−1)においてパラメタW(τ-1)が得られている。以下同様に、時刻τ,(τ+1),(τ+2),(τ+3)において、パラメタWτ,W(τ+1),W(τ+2),W(τ+3)がそれぞれ得られている。
図6(b)は、ある1つの(番号nの)微分値算出部2がパラメタW(τ)を受信し、微分値を算出してパラメタ更新部3に送信することを示している。
そして、図6(c)は、時刻tから時刻(t+1)へのパラメタ更新を示している。より具体的には、微分値算出部2がパラメタW(τ)を使って微分値を算出し、この微分値を使ってパラメタ更新部3がパラメタW(t)を更新してパラメタW(t+1)を得る様子を示している。なお、s=2であるから、W(t)=W(τ+2)であり、W(t+1)=W(τ+3)である。
詳細には次の通りである。微分値算出部2はパラメタW(τ)(符号22)を受信すると、モメンタム項V(τ)(符号21)を用いて、上記(7)式に基づいて予測値P(τ)(符号23)を算出する。この予測値P(τ)はパラメタW(τ-1)からパラメタW(τ)に向かう直線上にある。すなわち、微分値算出部2は、ある時点τのパラメタW(τ)と、その前の時点(τ−1)でのパラメタW(τ-1)とから、線形に予測値P(τ)を算出する。
なお、ステイルネスsnが大きいほど、予測値P(τ)とパラメタW(τ)との距離は大きくなる。ステイルネスsnが大きいほど、すなわち時刻τと時刻tとの差が大きいほど、時刻tにおいてパラメタW(τ)が大きく更新されているためである。
続いて、微分値算出部2は予測値P(τ)の微分値

を算出する(符号24の矢印、学習係数ηは省略した)。この予測値P(τ)は後述する符号23’で示すポイントを予測したものである。
このようにして微分値算出部2が微分値の算出を完了した時点で、パラメタ更新部3はパラメタW(τ+1)の算出を終えており、W(τ+2)を得るべくW(τ+1)を更新している最中である。その後、時刻(τ+2)(=τ+sn)でパラメタW(τ+2)が得られる(図6(a),(b)参照)。
その後、パラメタ更新部3は微分値算出部2から上記微分値を受信する。そして、パラメタ更新部3は、パラメタW(τ+2)にモメンタム項γV(τ+2)を加算したポイント(符号23’)に微分値(符号24’の矢印)をさらに加算してパラメタW(τ+2)を更新し、パラメタW(τ+3)を得る(符号25)。
このように、微分値算出部2は、受信した時刻τにおけるパラメタW(τ)をそのまま微分するのではなく、予測値P(τ)を算出してその微分値を算出する。本発明ではポイント23’を予測しており、その予測値がP(τ)(符号23)である。すなわち、本来であればポイント23’での微分を用いるべきところ、予測値P(τ)での微分を用いており、これらの距離d1が誤差である。一方、このような予測を行わない従来法では、本来であればパラメタW(t)での微分を用いるべきところ、パラメタW(τ)での微分を用いており、これらの距離d2が誤差となる。
図7は、ステイルネスsが45である非同期型分散学習システムにおいて、本発明による誤差d1と、従来法による誤差d2とを比較したグラフである。なお、従来法は非同期のモメンタム法とし、パラメタW(τ)を微分値算出に使用した。図示のように、従来法における誤差d2に比べ、本発明では誤差d1を1/34程度まで小さくすることができている。すなわち、微分値を算出するのに使うパラメタW(τ)と、その微分値を使って更新を実施するパラメタW(t)との距離(誤差)を小さくすることができ、微分の近似精度を高めることができる。
なお、更新式(7)では、2つの連続するパラメタW(τ-1),W(τ)から線形に予測値P(τ)を算出する例を示した。しかしながら、微分値算出部2は、時刻τ以前の連続または非連続の任意の数のパラメタを用いて線形または非線形に予測値P(τ)を算出してもよい。
例えば、微分値算出部2は、3つの連続するパラメタW(τ-2),W(τ-1),W(τ)から2次関数を用いて予測値P(τ)を算出してもよい。この場合、予測値P(τ)は3点W(τ-2),W(τ-1),W(τ)を通る2次関数上にあり、パラメタW(τ)と予測値P(τ)との距離がステイルネスsに応じて定まる。いずれにしても、微分値算出部2は、少なくとも2つのパラメタ(ある時点のパラメタと、それより過去の時点のパラメタ)に加え、ステイルネスsも用いて予測値P(τ)を算出すればよい。
別の例として、線形予測であるが、下式に基づいて予測値P(τ)を算出してもよい。

ここで、cはモメンタム係数であり、例えば次のようにして最適な値が設定される。本番の学習に先立って予備の学習を実施し、パラメタの時系列変化を取得し保存しておく。すると、図6(c)に示す誤差d1を最小とするモメンタム係数cを線形最小二乗法によって推定できる。本番の学習と同様の学習システムで予備の学習を行うことで、ステイルネスが反映された最適なモメンタム係数cが得られる。あるいは、本番の学習時にモメンタム係数cを推定してもよい。このようにして設定されたモメンタム係数cを本番の学習で使用することによって高速化を図れる。
このように、第1の実施形態では、非同期型の分散処理において、ステイルネスsを考慮して微分値を算出する。そのため、非同期であっても、すなわち微分を行う時点のパラメタと、パラメタ更新時のパラメタとが一致しない場合であっても、微分の近似精度が向上する。結果として、目的関数の降下速度が大きくなり、高速にパラメタを更新できる。
(第2の実施形態)
第1の実施形態の改善について説明する。学習の初期においては、微分値の絶対値が大きい値となることもある。そうすると、学習の初期においてパラメタの更新が不安定になりかねない。具体的には、目的関数の変動が大きくなりすぎて降下速度が小さくなることや、目的関数の値が無限大に発散してしまうなどといったことが起こらないとも限らない。
経験上、こういった不安定性は学習初期に特有であり、学習の中盤や後半では問題にはならない。学習初期の不安定性の改善には、次のような方法を採用するのが望ましい。
第1例として、パラメタ更新部3が微分値に学習係数ηを乗じた値を用いてパラメタ更新を行う際、学習係数ηを学習初期は小さい値としておき、学習が進むにつれて徐々に大きくして目的の値まで変化させることが考えられる。これはパラメタ更新のステップ幅を学習初期において低減することを意味しており、予測するパラメタの範囲を比較的小さい範囲にとどめることが出来る。結果として不安定性改善に効果的である。
第2例として、微分値算出部2が学習初期はモメンタム係数γを小さい値としておき、学習が進むにつれて徐々に大きくして目的の値まで変化させることが考えられる。この場合、学習初期においては、モメンタム項Vの効果が小さく、非同期の基本的な更新式がより近似的に実現されるため、安定性が期待できる。
第1例および第2例とも、微分値が「暴れる」可能性がある学習初期を通過した後は目的の値に達することが望ましい。
第3例として、学習の初期は安定的動作が可能な学習法(例えば同期型の学習法)を事前学習として適用し、学習初期を通過した後は、上述した第1の実施形態に係るパラメタ更新を行ってもよい。
第4例として、第1の実施形態に係るパラメタ更新を学習初期から適用するが、得られたパラメタの更新量の絶対値がある値を超えないように閾値操作を適用してもよい。
第5例として、第1の実施形態に係るパラメタ更新を学習初期から適用するが、得られたパラメタの更新量のノルムがある値を超えないように、更新量をリスケールしてもよい。
なお、第1〜第5例のうちの任意の2以上を組み合わせてもよい。
このように、第2の実施例では、学習初期に目的関数が不安定となるのを抑制できる。
図8は、学習を繰り返し行った際のエラー率の遷移を示すグラフである。横軸は時刻あるいは更新回数である。縦軸はエラー率すなわち認識の不正解率である。曲線f1が第1の実施形態に第2の実施形態における第1例を適用した結果である。曲線f2〜f4は参考であり、順に、非同期型モメンタム法、同期型モメンタム法および同期型NAG法の結果である。ミニバッチサイズをほぼ共通(約256)とし、CNN構造は、10層の畳み込み層、1層の全結合層および5層のプーリング層とした。
図示のように、例えばエラー率20%を目標とすると、本発明に基づく曲線f1は他の曲線f2〜f4より短時間(曲線f3の約半分)で目標に達した。このことから、本発明の有用性が示された。
上述した実施形態は、本発明が属する技術分野における通常の知識を有する者が本発明を実施できることを目的として記載されたものである。上記実施形態の種々の変形例は、当業者であれば当然になしうることであり、本発明の技術的思想は他の実施形態にも適用しうることである。したがって、本発明は、記載された実施形態に限定されることはなく、特許請求の範囲によって定義される技術的思想に従った最も広い範囲とすべきである。
1 ノード
2 微分値算出部
3 パラメタ更新部

Claims (8)

  1. 複数の微分値算出部と、パラメタ更新部とを備え、ニューラルネットワーク用のパラメタを更新する学習システムであって、
    前記複数の微分値算出部のそれぞれは、互いに同期することなく、
    前記パラメタ更新部からある時点でのパラメタを受信し、
    前記受信したパラメタに基づいて、前記パラメタの更新に用いられる微分値を算出し、
    前記微分値を前記パラメタ更新部に送信し、
    前記パラメタ更新部は、
    前記微分値算出部から前記微分値を受信し、
    前記複数の微分値算出部による微分値算出と同期することなく、受信した前記微分値に基づいて前記パラメタを更新し、
    更新後の前記パラメタを前記複数の微分値算出部に送信し、
    前記微分値算出部は、前記微分値を算出する際、前記パラメタの受信時点から、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられるまでの間に前記パラメタが更新される回数に対応するステイルネスを考慮して、前記微分値を算出する、学習システム。
  2. 前記微分値算出部は、前記ステイルネスを考慮して、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられる際のパラメタの予測値を算出し、該予測値を微分して前記微分値を算出する、請求項1に記載の学習システム。
  3. 前記微分値算出部は、前記ある時点でのパラメタと、前記ある時点より過去の時点でのパラメタと、前記ステイルネスと、に基づいて、前記予測値を算出する、請求項2に記載の学習システム。
  4. 前記ある時点τでのパラメタをW(τ)とし、前記ある時点より過去の時点(τ−1)でのパラメタをW(τ-1)とし、前記ステイルネスをsnとし、γ∈(0,1)とするとき、前記予測値P(τ)は、

    である、請求項3に記載の学習システム。
  5. 前記微分値算出部は、学習が進むにつれて前記γを大きな値とする、請求項4に記載の学習システム。
  6. 前記パラメタ更新部は、前記微分値に学習係数を乗じた値を用いて前記パラメタを更新し、学習が進むにつれて前記学習係数を大きな値とする、請求項1乃至5のいずれかに記載の学習システム。
  7. 前記パラメタ更新部は、学習の開始から所定回パラメタ更新が行われるまでは、前記パラメタの更新量の絶対値またはノルムが所定値を超えないよう、更新量を調整する、請求項1乃至6のいずれかに記載の学習システム。
  8. ニューラルネットワーク用のパラメタを更新する学習方法であって、
    複数の微分値算出部のそれぞれが、互いに同期することなく、
    パラメタ更新部からある時点でのパラメタを受信するステップと、
    前記受信したパラメタに基づいて、前記パラメタの更新に用いられる微分値を算出するステップと、
    前記微分値を前記パラメタ更新部に送信するステップと、
    前記パラメタ更新部が、
    前記微分値算出部から前記微分値を受信するステップと、
    前記複数の微分値算出部による微分値算出と同期することなく、受信した前記微分値に基づいて前記パラメタを更新するステップと、
    更新後の前記パラメタを前記複数の微分値算出部に送信するステップと、を備え、
    前記微分値算出部が前記微分値を算出するステップでは、前記パラメタの受信時点から、算出した微分値が前記パラメタ更新部によって前記パラメタの更新に用いられるまでの間に前記パラメタが更新される回数に対応するステイルネスを考慮して、前記微分値を算出する、学習方法。
JP2016253169A 2016-12-27 2016-12-27 学習システムおよび学習方法 Active JP6704583B2 (ja)

Priority Applications (3)

Application Number Priority Date Filing Date Title
JP2016253169A JP6704583B2 (ja) 2016-12-27 2016-12-27 学習システムおよび学習方法
US15/795,691 US11521057B2 (en) 2016-12-27 2017-10-27 Learning system and learning method
CN201711425028.XA CN108241889A (zh) 2016-12-27 2017-12-25 学习***与学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
JP2016253169A JP6704583B2 (ja) 2016-12-27 2016-12-27 学習システムおよび学習方法

Publications (2)

Publication Number Publication Date
JP2018106489A JP2018106489A (ja) 2018-07-05
JP6704583B2 true JP6704583B2 (ja) 2020-06-03

Family

ID=62629916

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2016253169A Active JP6704583B2 (ja) 2016-12-27 2016-12-27 学習システムおよび学習方法

Country Status (3)

Country Link
US (1) US11521057B2 (ja)
JP (1) JP6704583B2 (ja)
CN (1) CN108241889A (ja)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11043205B1 (en) * 2017-06-27 2021-06-22 Amazon Technologies, Inc. Scoring of natural language processing hypotheses
JP7208758B2 (ja) * 2018-10-05 2023-01-19 株式会社デンソーアイティーラボラトリ 学習方法および学習システム
DE102019002790B4 (de) 2019-04-16 2023-05-04 Mercedes-Benz Group AG Verfahren zur Prädiktion einer Verkehrssituation für ein Fahrzeug
KR20210020387A (ko) 2019-08-14 2021-02-24 삼성전자주식회사 전자 장치 및 그 제어 방법
CN111626434B (zh) * 2020-05-15 2022-06-07 浪潮电子信息产业股份有限公司 一种分布式训练参数更新方法、装置、设备及存储介质
KR102555268B1 (ko) * 2020-12-14 2023-07-13 한양대학교 산학협력단 파라미터 서버 기반 비대칭 분산 학습 기법

Also Published As

Publication number Publication date
JP2018106489A (ja) 2018-07-05
US11521057B2 (en) 2022-12-06
US20180181863A1 (en) 2018-06-28
CN108241889A (zh) 2018-07-03

Similar Documents

Publication Publication Date Title
JP6704583B2 (ja) 学習システムおよび学習方法
US11941527B2 (en) Population based training of neural networks
US20140214735A1 (en) Method for an optimizing predictive model using gradient descent and conjugate residuals
CN108764568B (zh) 一种基于lstm网络的数据预测模型调优方法及装置
Baram et al. Model-based adversarial imitation learning
US11449731B2 (en) Update of attenuation coefficient for a model corresponding to time-series input data
US11100388B2 (en) Learning apparatus and method for learning a model corresponding to real number time-series input data
US20220108215A1 (en) Robust and Data-Efficient Blackbox Optimization
Dong et al. Towards adaptive residual network training: A neural-ode perspective
US20180314978A1 (en) Learning apparatus and method for learning a model corresponding to a function changing in time series
CN114065863A (zh) 联邦学习的方法、装置、***、电子设备及存储介质
CN113487039A (zh) 基于深度强化学习的智能体自适应决策生成方法及***
CN113935489A (zh) 基于量子神经网络的变分量子模型tfq-vqa及其两级优化方法
Karda et al. Automation of noise sampling in deep reinforcement learning
El-Laham et al. Policy gradient importance sampling for Bayesian inference
Sun et al. Time series prediction based on time attention mechanism and lstm neural network
Banerjee et al. Boosting exploration in actor-critic algorithms by incentivizing plausible novel states
KR20200000660A (ko) 실시간 시계열 데이터를 위한 예측 모형 생성 시스템 및 방법
JPWO2019142241A1 (ja) データ処理システムおよびデータ処理方法
Yang Kalman optimizer for consistent gradient descent
Ma et al. Long-Term Credit Assignment via Model-based Temporal Shortcuts
CN113807005B (zh) 基于改进fpa-dbn的轴承剩余寿命预测方法
Guo et al. Optimal control of blank holder force based on deep reinforcement learning
CN114139677A (zh) 一种基于改进gru神经网络的非等间隔时序数据预测方法
Zheng et al. Variance reduction based partial trajectory reuse to accelerate policy gradient optimization

Legal Events

Date Code Title Description
A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20190415

A977 Report on retrieval

Free format text: JAPANESE INTERMEDIATE CODE: A971007

Effective date: 20200309

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20200324

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20200330

R150 Certificate of patent or registration of utility model

Ref document number: 6704583

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R150

R250 Receipt of annual fees

Free format text: JAPANESE INTERMEDIATE CODE: R250

R250 Receipt of annual fees

Free format text: JAPANESE INTERMEDIATE CODE: R250