CartPoleでDQN(deep Q-learning)、DDQNを実装・解説【Phythonで強化学習:第2回】

Deeplearningを用いた強化学習手法であるDQNとDDQNを実装・解説します。

学習対象としては、棒を立てるCartPoleを使用します。

前回記事では、Q-learning(Q学習)で棒を立てる手法を実装・解説しました。

CartPoleについて詳細は、こちらを御覧ください。

CartPoleでQ学習(Q-learning)を実装・解説【Phythonで強化学習:第1回】
強化学習で倒立振子(棒を立て続ける)制御を実現する方法を実装・解説します。本回ではQ学習(Q-learnin...

Q学習では、行動価値関数Q(s_t, a_t)をテーブル(表)で実現しました。

DQN(Deep Q-learning)では、この関数Qをディープラーニングを用いたディープニューラネットワークで表します。

本記事でははじめに、DQNについて説明します。

ここでは、なぜDQNが必要とされたのかを説明し、その後DQNの4つの工夫点を紹介します。

そしてDDQNについて説明します。

その後、コードを紹介しながら、実装方法と内容を解説します。

DQNとは

前回実装した、テーブル・表を用いたQ学習には困ったことが2つあります。

1つ目は、状態s(t)を離散化する必要があることです。

カートの位置x(t)などは連続値ですが、表で表すためには、適当な範囲内で分割してあげる必要があります。

こうした連続空間で表現される状態変数の離散化は精度を下げる要因となります。

2つ目の困った点は、画像などを入力できない点です。

棒立てのCartPoleくらいであれば状態s(t)を簡単に決めれますが、例えばテトリスなんかだとちょっとしんどい気がします。

画像を直接入力して強化学習できれば、嬉しいです。

このような2つの困った点を解決するのに、

「DL(ディープラーニング)使えばよくない?(とくにCNN)」

って流れになり、Q関数にディープニューラルネットワークを使用する手法が生まれました。

Qネットワーク

最初はQ関数にディープニューラルネットワークを用いるというと、ぱっとイメージが分かないのですが、入力層は状態空間の次元数になります。

CartPoleであれば、カートの位置、速度、棒の角度、角速度の4次元なので、入力層のニューロン数は4つです。

そして各ニューロンに、カートの位置x(t)などの連続値をそのまま入力します。

その後、多段のニューラルネットワークを経て出力層に至ります。

出力層のニューロンの数は行動の選択肢数です。

CartPoleの場合は右か左にCartを押すので、出力層は2つのニューロンになります。

そして2つのニューロンがそれぞれのQ(s_t, a_t)の値

すなわち、状態s(t)で行動a(t)をとった場合にその後得られる報酬の総計

を出力します。

これでQネットワークの入力と出力の関係が分かりました。

次に学習の説明です。

このQネットワークに状態s(t)を入力すれば、出力ニューロン1はQ(s_t, 右に押す)、出力ニューロン2はQ(s_t, 左に押す)という値を出力します。

これらの値が正確であれば嬉しいのですが、学習途中では間違った値となっています。

そこで、その時点で正しそうなQ関数の値を求めます。

時刻tで右に押したとすると、実際にもらった報酬r(t)と、その結果状態がs(t+1)になった場合には、max[Q(s_{t+1},a_{t+1})]に時間割引γを掛けたものの和

すなわち、r(t)+γ・max[Q(s_{t+1},a_{t+1})]

が、現状でもっとも正しそうなものとなります。

よって、現状のQネットワークの出力Q(s_t, 右に押す)が、

r(t)+γ・max[Q(s_{t+1},a_{t+1})]

の値に近くなるように、ネットワークの重みを学習してあげます。

重みを学習するとは、

Q(s_t, 右に押す) と r(t)+γ・max[Q(s_{t+1},a_{t+1})]

の誤差を減らす方向に、バックプロパゲーションで各層の各結合の重みを少し変化させるということになります。

以上で、Q関数にディープラーニングを取り入れることができました。

ですが、このまま実際に強化学習をやってもうまくいきません。

DQNにするためには次の4つの工夫が必要です。

DQNの手法

DQNは単純にQ関数をDL(ディープラーニング)にする以外に、4つの工夫があります。

DLで時系列情報を学習するときに、各stepごとに学習すると、時間方向の相関が強くでてしまい、うまく学習ができないという問題があります。

そこで、Experience Replayという手法が実装されます。

逐次順番に学習するのではなく、{s(t), a(t), r(t), s(t+1)}をたくさんメモリに保持しておいて、あとでランダムに学習するという手法をとります。

※r(t)は時刻tでもらえる報酬です。

工夫の2つ目はFixed Target Q-Networkです。

これは、Qネットワークを学習する際に、max[Q(s_{t+1},a_{t+1})]を同じQネットワークから求めるのではなく、少し前の固定しておいたQネットワークを使用するという方法です。

工夫1, 2はメモリからランダムに取り出した学習データをバッチ学習すれば、自然と実現されます。

工夫3は報酬のclippingです。

各ステップで得られる報酬を-1, 0, 1のいずれかに固定しておく方法です。

こうすることで、ゲーム内容(学習対象)によらず、同じハイパーパラメータのDLを使用できるというメリットがあります。

最後の工夫4が誤差関数の工夫です。

DLでは現状のQネットワークの出力Q(s_t, 右に押す)が、

r(t)+γ・max[Q(s_{t+1},a_{t+1})]

の値に近くなるように、ネットワークの重みを学習してあげます。

このときに二乗誤差{r(t)+γ・max[Q(s_{t+1},a_{t+1})] – Q(s_t, 右に押す)}^2

を使用してバックプロパゲーションを計算するのが一般的です。

ですが、この二乗誤差が1以上になる場合には、

abs{r(t)+γ・max[Q(s_{t+1},a_{t+1})] – Q(s_t, 右に押す)}

を使用することにします。

こうすることで学習が安定しやすいというメリットがあります。

参考:DQNをKerasとTensorFlowとOpenAI Gymで実装する

なお、このような誤差関数(損失関数)はHuber関数と呼ばれます。

DQNでは、以上の4点を工夫として、実装してあげます。

DDQN

最後にDDQN(Double DQN)を説明します。

Q学習のときに行動選択にQ関数をFixed Target Q-Networkで1バッチ分前の関数を使用していましたが、もっと前の別のQ関数を使用する手法をDDQNと言います。

今回は、1試行前のQネットワークを使用することでDDQNを実現します。

DDQNの詳細はこちらが参考になります

introduction to double deep Q-learning(日本語)

DQN、DDQNの実装とコード解説

紹介しました4つの工夫を念頭に、コードを説明していきます。

DQN、DDQNは同じひとつのファイルで、途中の設定パラメータで切り替えています。

実装にはこれらを参考にしました。

OpenAI Gymの CartPole問題をDQNで解いた

udacity/deep-learning

AI-blog/CartPole-DQN.py

最初は必要なライブラリのインポートです。

今回はディープラーニングのライブラリとして、Kerasを使用しています。

続いて、1つの関数と、3つのクラスを定義します。

工夫点4のHuber損失関数を定義しています。

Kerasには、Huber関数は入っていないので、自分で定義します。

誤差が±1以上の場合で場合分けして、二乗誤差と絶対誤差の小さい方を使用します。

Kerasではバッチ内の施行ごとに異なる損失関数を設定できないので、tensorflowのwhere関数を使用します。

Q関数として使用するディープニューラルネットワークのクラスを定義します。

今回は入力層+中間層2層+出力層です。

画像で示すとこんな感じです。

そして、Qネットワークの各重みを学習する関数を定義します。

if notの部分は、次の状態s(t+1)が、終了でない(こけたり、200step経過)していないときは教師信号となる報酬はr(t)+γ・max[Q(s_{t+1},a_{t+1})]とし、次の状態がなく、時刻tで終了したときは教師信号となる報酬はr(t)として、学習しています。

ここでは学習する内容を保存しておく記憶クラスを定義しています。

addが追加、sampleはランダムに引数分だけ保存内容を取り出します。

lenは現在保存されている量を返します。

最後のクラスは状態s(t)に応じて、Cartを右に押すか左に押すかを決めるActorのクラスを定義しています。

ここではε-greedy法を採用し、最初は探索をするが、徐々に最適行動のみを行うようにしています。

ここからメイン関数がはじまります。

最初は定数を定義します。

DQN_MODEを1にするとDQN、0にするとDDQNとなります。

作成した3つのクラスからインスタンスを生成します。

Qネットワークについては、メインで使用し、価値関数を計算したり、行動を決定するためのmainQNと、max[Q(s_{t+1},a_{t+1})] を計算するtargetQNを分離しています。

またネットワークの状態を画像で出力するには、plot_modelを使用します。

ただし、他のライブラリをインストールしておく必要があります。

でgraphvizとpydotを入れておいてください。

次にメインの部分に入ります。

実際のメインルーチンが始まります。

for文のネストになっており、試行数のfor文のなかに、時間stepのfor文があります。

行っていることは、

1. actorから最適な行動を求める

2. 実際にその行動を行って、s(t+1)とr(t)を求める

※s(t+1)はKerasで扱えるようにlist型のstateを、1行4列の行列に変換

3. if doneは、終了状態になった場合は、s(t+1)にゼロ行列を入れておき、

4. 報酬r(t)を決める

今回は立っているときは報酬0、195ステップ以上立って終了したら報酬1、それ以前に倒れたら報酬-1を与えています。

5. s(t), a(t), r(t), s(t+1)をメモリに追加します。

6. Q学習を行います。

メモリ内容がバッチサイズより大きくなったステップから学習を開始します。

DQNの場合は学習後に価値観数を返すmainQNと、行動を決めるtargetQNを同じにします。

DDQNの場合は、各試行の最初だけ価値観数を返すmainQNと、行動を決めるtargetQNを同じにしています。

7. 最後は、各ステップの状況の出力と、各試行終わりの平均報酬から学習終了を判断しています。

以上で、DQN、DDQNが実装されました。

どちらも100試行程度で、学習が収束します。

DDQNの方が早い気がします。

また学習後の挙動の例は、次のような感じです。

DQN

DDQN

その他、いろいろとやっていて感じたのは、なんかもうちょっとで学習うまくいきそうで、失敗するときは、ネットワークの学習係数を下げるとうまくいくことが多いです。

Huber関数と、DDQNはこれで良いのかちょっと不安ですが、棒が立ったので、嬉しいです

本ページが参考になれば幸いです。

最後にコードをまとめて掲載します。

以上、DQN、DDQNでCartPoleを解く強化学習の実装でした。