trlxを用いた文書生成モデルの学習①~ILQL編~

こんにちは
AIチームの戸田です

今回は最近話題のChatGPTの学習に使われているRLHF(Reinforcement Learning from Human Feedback)を行うことができる強化学習フレームワーク、trlxを使った文章生成を試してみたいと思います。

trlxは強化学習手法としてILQL(Implicit Language Q-Learning)PPO(Proximal Policy Optimization)の2種類が用意されており、それに加えて通常の言語モデルの学習であるSFT(Supervised Fine-Tuning)も実装されています。

本記事では日本語感情分析データセットWRIMEのデータでILQLを使った学習を行い、PPOは次回の記事で試したいと思います。trlxライブラリを一通り動かすことを目的とし、パラメータ調整やデータクレンジングなどのより良い生成を行うための工夫は本記事では行いません。

なお、2023年2月28日現在trlxはv0.5が最新ですが、本記事はv0.4で試した際のコードになります。
[2023.03.01 追記]
執筆していた時期がv0.4とv0.5リリースの間で、tag付されていない2023年2月22日のcommit時点のコードで実験していたので、本記事と同じ実験環境を再現するには、このcommitにcheckoutした上でtrlxのinstallをしていただければと思います(ご指摘いただいた@moriokaさん、ありがとうございました!)

RLHF

RLHFはOpenAIの出したFine-Tuning Language Models from Human Preferencesという論文で提案されたGPTなどの言語モデルを人間のフィードバックを利用してFine-Tuningする手法です。Wikipediaなどの一般的な文章で学習した言語モデルを人間の好みに合わせた出力ができるようになると言われています。人間のフィードバックは多くの場合微分できないので、PPOなどの強化学習を使用します。

下の図はOpenAIがRLHFでredditの要約を行った際の論文の図です。

RLHFのフロー図
RLHFのフロー図

実際はこの前に❶で使う言語モデルの学習(SFT)が入りますが、RLHFの流れを分かりやすくまとめてくれている図だと思ったので転記させていただきました。

trlxで何ができるかを紹介するために、まずは大まかにRLHFの流れを説明します。(数式などより詳細が知りたい方はOpenAIの元論文をご参照ください。)

Step 0: SFT

prompt(入力文)とそれに対する適切なoutput(出力文)のペアを用いて言語モデルの学習を行います。
従来の言語生成モデルの多くはこの学習を行っていましたが、学習時のpromptの表現に大きく寄ってしまったり、文章としては正しいけど人間の好みに合わなかったりと多くの課題がありました。

Step 1: 人間のフィードバックを集める

Step 0で学習したモデルが出力したoutputを複数パターン用意し、人間にどれが良いのかをラベルづけ(=フィードバック)してもらいます。ChatGPTの前身のInstructGPTは順位づけによってどのoutputがよいのかをアノテーションしているようです。
図中❶ Collect human feedback

Step 2: reward modelの学習

Step 1でつけられたラベルをもとにreward model、つまりあるpromptに対するoutputがどれほど良いのかを評価する報酬モデルを学習させます。
単純な良し悪しの2値分類も考えられますし、InstructGPTのように回答の順位付けを予測するタスクを解かせる場合もあります。
図中❷ Train reward model

Step 3: 強化学習によるFine-Tuning

PPOなどの強化学習手法を使ってStep 0で学習した言語モデルをさらにFine-Tuningします。強化学習を行うための報酬はStep 2で学習されたreward modelの出力を使います。また、ここでは学習を安定させるために、報酬の正規化やKLダイバージェンスによる正則化項といったテクニックを組み合わせています。
図中❸ Train policy with PPO

trlxは主にStep 3の学習を行うフレームワークになります。報酬の正則化なども内部で実装されており、学習したい言語モデルやreward modelを指定することでStep 3の強化学習によるFine-Tuningを実行することができます。また、ChatGPTで行われているようなRLHFではないのですが、ILQLのようなオフライン強化学習を使った言語モデルのFine-Tuningも行うことができます。
trlxにはStep 0のSFTの学習を行えるAPIも提供されています。ただし学習率やoptimizerなどの基本的なパラメータ設定しか行うことができないので、凝った設定で学習したい場合は独自のパイプラインを構築した方が良いかもしれません。

なお、Step 2の人間のフィードバックを集めるためのツールとしてtrlxを提供しているCarperAIがCHEESEというライブラリを公開しています。興味のある方はご参照いただければと思います。

実験

強化学習を使った言語モデルのFine-Tuningには、どうしても人手評価したデータセットが必要になるので、試すのはなかなか難しいかと思っていたのですが、trlxの例を見ると、映画の感想の極性分類データセット(IMDB)を上手く利用していました。極性ラベルは人間がつけたポジティブかネガティブかのフィードバックといえ、それを学習した極性分類モデルはreward modelとして使えることが期待されます。これらを利用してポジティブな映画の感想を生成するように学習していました。

本記事ではまずreward modelを使わない、オフライン強化学習手法であるILQLを使って、強化学習による言語モデルのFine-Tuningを試してみたいと思います。また、2023年2月28日現在、私の調べた限りでは、日本語のデータセットでtrlxでの学習を試した例がなかったので、初の日本語データでのトライとして、日本語感情分析データセットWRIMEを使うことにしました。

WRIME

WRIMEは日本語の感情分析の研究のためのデータセットです。以下にREADMEにあるデータセットの特徴を転記します。

  • 主観(テキストの筆者1人)と客観(クラウドワーカ3人)の両方の立場から感情ラベルを付与しました。
  • Plutchikの基本8感情(喜び、悲しみ、期待、驚き、怒り、恐れ、嫌悪、信頼)を扱いました。
  • 各感情の強度を4段階(0:無、1:弱、2:中、3:強)でラベル付けしました。
  • Ver.2では、感情極性(-2:強いネガティブ、-1:ネガティブ、0:ニュートラル、1:ポジティブ、2:強いポジティブ)も追加しました

ちなみにWRIMEはHuggingFace Hubからも利用することができます。

今回はver1の喜び(joy)と悲しみ(sadness)をそれぞれ学習してみたいと思います。

実装

ここからは実際のコードと一緒に説明していきたいと思います。

インストール手順は環境によって異なると思いますが、私の環境(google colab T4)だと以下のような手順でインストールできました。

git clone https://github.com/CarperAI/trlx.git
cd trlx && git checkout 93c90cbdc3c6b463f565b09340ca1f74271285c5
git config --global --add safe.directory /content/trlx && cd /content/trlx && pip install -e .

pip uninstall -y scikit_learn jax # numpyのerrorが出てしまう
pip install sentencepiece  # マルチバイトtokenizerを使うので

以下は必要なライブラリのインストールです。

import os
os.chdir(‘/content/trlx’)

import yaml
from datasets import load_dataset
import pathlib
import trlx
from trlx.data.configs import TRLConfig

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

計算時間的にGPUは必須だと思います。batch sizeを調整すればgoogle colabでも動くのでGPUマシンを持っていない方も是非試してみてください。

configの上書き

学習の設定ファイルをyamlで記述してそれを読み込むのですが、今回は公式のサンプルを利用して、必要なところだけ上書きしようと思います。

model_name = 'rinna/japanese-gpt2-medium'

with open('configs/ilql_config.yml') as f:
    default_config = yaml.safe_load(f)

default_config['train']['tracker'] = None
default_config['train']['batch_size'] = 16

default_config['model']['model_path'] = model_name
default_config['tokenizer']['tokenizer_path'] = model_name

config = TRLConfig.update(default_config, {})

model_pathとtokenizer_pathはデフォルトでは英語版のGPTが設定されているので、日本語を扱えるrinna/japanese-gpt2-mediumに変更します。

batchサイズはデフォルト128のところ、私の使用しているGPU(T4)でも動くように小さく変更しています。パラメータとしてseq_lengthもあるので、環境によってはこちらを調整するのが良いかもしれません。

trackerはNoneを指定していますが、コードを確認したところ、wandbとtensorboadを指定できるようです。

データセットの準備

HuggingFace HubからWRIMEのデータセットを取得します。

wrime = load_dataset("shunk031/wrime", name="ver1")

train_sentence = wrime['train']['sentence']
valid_sentence = wrime['validation']['sentence']
sentence = train_sentence + valid_sentence

senti = 'joy'  # or 'sadness'
train_labels = [r[senti] for r in wrime['train']['avg_readers']]
valid_labels = [r[senti] for r in wrime['validation']['avg_readers']]
labels = train_labels + valid_labels
labels = [int(i != 0) for i in labels]

今回は特にラベルを使った評価は行わないので、trainデータとvalidationデータを混ぜてしまいます。joy(またはsadness)をラベルとし、感情強度は4段階ですが、1でもフラグがたっていれば1、そうでなければ0とします。(10.0, 100.0のような連続値もラベルとすることができますが、今回は簡略化のためにこうしています。)

評価用プロンプト

testデータ中のUserID 49の以下のデータから、学習中に使う評価用のプロンプトを作りました。

DatetimeSentence
2016/06/21 10:16今週にでもヒトカラ行こうかな
2016/07/02 11:38今日6時だと思ってたけど、4時からやんな…危ない
2016/07/03 20:58帰ったらSASUKE見るぞ?!
eval_prompts = [
    '今週末に',
    '今日6時',
    '帰ったら',
]

先頭4文字を取ってきています。

学習

以下のコードを実行することで学習が始まります。

trainer = trlx.train(
    samples=sentence,
    rewards=labels,
    eval_prompts=eval_prompts,
    config=config,
)

出力

最終epochの評価promptの出力結果例を以下に示します。

joy

PromptOutput
今週末にもう一回行ってみたいね。 私、まだ行きてええ... って思ってるよ。
今日6時すぎに寝て、 10時半前に起きて、7時に飯食べて、9時半に寝て、 11時半過ぎに起きて、10時半に飯食べて、 10-
帰ったらCC-C-で寝て、起きたらCのC-が来てて、 なんだこれ!と、思ってももうすぐ3歳になる子どもが「え〜!何

sadness

PromptOutput
今週末に、witterのアカウントをまた作り直そうと思うのですが、 また同じことをしてアカウントを2つ作れないのは、...
今日6時今日6時起きだ、頑張らなきゃいけない
帰ったら帰ったらやるべきことはやったので寝ます

文章として意味不明なものが多く、joyとsadnessで感情による差はみられないようです。
今回はStep 0のSFTのフェーズをスキップしてしまったので、思ったような出力が得られなかったのではないかと考えています。パラメータが初期設定のままだったことも原因の一つかもしれません。まだまだ改善の余地がありそうです。

おわりに

本記事ではtrlxに実装されているオフライン強化学習手法のILQLを使って言語モデルのFine-Tuningを試してみまた。

生成結果はお粗末なものでしたが、とりあえずtrlxライブラリの使い方はわかったので、次回PPOによるFine-Tuningを試してみたいと思います。

最後までお読み頂き、ありがとうございました!

PICK UP

TAG