OpenAI ChatGPTの出力の尤もらしさをトークンレベルで可視化する

はじめに

こんにちは、AIチームの杉山です。
LLMを組み込んだプロダクトは、その出力にハルシネーション(幻覚)が含まれる可能性を考慮し、後処理や処理フローを整えて正常系として動作することが必要です。ハルシネーションを検知する方法は様々提案されていますが[1][2][3]、今回の記事ではOpenAI社のChatGPTが出力するlogprobを用いてトークン単位での出力がどれくらい尤もらしいかを可視化したいと思います。

ChatGPT Log Probability

ChatGPTのAPIでは、リクエストにlogprobs=Trueを指定することでトークン単位でのLog Probabilityを取得することができます。[4] Log Probabilityは、コンテキストが与えられたときに各トークンがシーケンスに出現する対数確率で、本記事では出力におけるそのトークンの尤もらしさとして扱います。

from openai import OpenAI
client = OpenAI(api_key="YOUR_API_KEY")

completion = client.chat.completions.create(
  model="gpt-4-1106-preview",
  messages=[
    {"role": "user", "content": "世界で一番背の高い人の身長は何センチですか?"}
  ],
  logprobs=True,
  max_token=100,
  temperature=0,
)

print(completion.choices[0].message.content)
# '2023年の時点で、世界で一番背の高い人としてギネス世界記録に認定されているのはトルコ出身のスルタン・ケーセン(Sultan Kösen)です。
# 彼の身長は公式には251センチメートル(8フィート2.8インチ)と記録'

print(completion.choices[0].logprobs)
# [ChatCompletionTokenLogprob(token='202', bytes=[50, 48, 50], logprob=-0.8971278, top_logprobs=[TopLogprob(token='202', bytes=[50, 48, 50], logprob=-0.8971278)]),
# ChatCompletionTokenLogprob(token='3', bytes=[51], logprob=-0.021846717, top_logprobs=[TopLogprob(token='3', bytes=[51], logprob=-0.021846717)]),
# ChatCompletionTokenLogprob(token='年', bytes=[229, 185, 180], logprob=-6.337155e-05, top_logprobs=[TopLogprob(token='年', bytes=[229, 185, 180], logprob=-6.337155e-05)]),
# ChatCompletionTokenLogprob(token='の', bytes=[227, 129, 174], logprob=-1.0081867, top_logprobs=[TopLogprob(token='の', bytes=[227, 129, 174], logprob=-1.0081867)]),
# ChatCompletionTokenLogprob(token='時', bytes=[230, 153, 130], logprob=-0.07075444, top_logprobs=[TopLogprob(token='時', bytes=[230, 153, 130], logprob=-0.07075444)]),
# ChatCompletionTokenLogprob(token='点', bytes=[231, 130, 185], logprob=-9.0883464e-07, top_logprobs=[TopLogprob(token='点', bytes=[231, 130, 185], logprob=-9.0883464e-07)]),
# ChatCompletionTokenLogprob(token='で', bytes=[227, 129, 167], logprob=-0.07530445, top_logprobs=[TopLogprob(token='で', bytes=[227, 129, 167], logprob=-0.07530445)]),
# ChatCompletionTokenLogprob(token='、', bytes=[227, 128, 129], logprob=-0.5530573, top_logprobs=[TopLogprob(token='、', bytes=[227, 128, 129], logprob=-0.5530573)]),
# ChatCompletionTokenLogprob(token='\\xe4\\xb8', bytes=[228, 184], logprob=-0.569853, top_logprobs=[TopLogprob(token='\\xe4\\xb8', bytes=[228, 184], logprob=-0.569853)]),
# ChatCompletionTokenLogprob(token='\\x96', bytes=[150], logprob=-1.1637165e-05, top_logprobs=[TopLogprob(token='\\x96', bytes=[150], logprob=-1.1637165e-05)]),
# ChatCompletionTokenLogprob(token='界', bytes=[231, 149, 140], logprob=-0.0007354162, top_logprobs=[TopLogprob(token='界', bytes=[231, 149, 140], logprob=-0.0007354162)]),
# ChatCompletionTokenLogprob(token='で', bytes=[227, 129, 167], logprob=-0.0027969147, top_logprobs=[TopLogprob(token='で', bytes=[227, 129, 167], logprob=-0.0027969147)]),
# ... ]

Webで検索すると、どうやら回答自体は正しそうです。

このままではどの箇所がlogprobが小さい(確信度が低い)のか分かりづらいため、折れ線グラフで表示してみます。

import matplotlib.pyplot as plt

indexes = range(len(tokens))
# 日本語を表示するためのフォント指定
plt.rcParams['font.family'] = 'YOUR_FONT'
plt.figure(figsize=(30, 3))
# 折れ線グラフの作成。インデックスを横軸として使用
plt.plot(indexes, logprobs, marker='o')


# 横軸(x軸)に目盛りラベルを設定
# 表示の際に文字が被って見えなくなるため縦表示
plt.xticks(indexes, tokens, rotation="vertical")


# 軸のラベルおよびグラフのタイトルを設定
plt.xlabel('Token')
plt.ylabel('Log Probability')
plt.title('Line Graph of Probabilities for Each Token')

# グラフを表示
plt.show()

グラフとして表示することはできましたが、一部のトークンがByte-fallbackしており日本語として何であるかを視認することができないため、トークンをマルチバイト文字として表示できる様にまとめる処理を行います。この時、複数のトークンで1つの文字となるためその文字に対するLog Probabilityは各トークンのlogprobの平均値を取ることにして、GPT-4にコードを作成してもらいました。

chars = []
char_logprobs = []

buffer = ''
logprob_sum = 0.0
logprob_count = 0

def decode_buffer():
    global buffer, logprob_sum, logprob_count, chars, char_logprobs
    # Decode the buffered hex string and add to tokens
    try:
        char = bytes.fromhex(buffer).decode('utf-8')
    except Exception as e:
        raise e
    chars.append(char)
    # Calculate the average logprob for these sequences and add to logprobs
    char_logprobs.append(logprob_sum / logprob_count)
    # Reset buffer and related variables
    buffer = ''
    logprob_sum = 0.0
    logprob_count = 0

for token, logprob in zip(tokens, logprobs):

    # If token is Japanese character, output immediately
    if '\\' not in token:
        if buffer:
            decode_buffer()
        chars.append(token)
        char_logprobs.append(logprob)
    else:
        # Buffer hex values without '\\x'
        hex_value = token.replace("\\x", "")
        buffer += hex_value
        logprob_sum += logprob
        logprob_count += 1

        # If buffer has enough data for one character, decode it
        if len(buffer) >= 6:  # 6 hex digits needed for one UTF-8 character
            try:
                decode_buffer()
            except Exception as e:
                print(e)

上記の処理を行った後、再度描画してみます。

先ほどByte-fallbackしていたトークンがまとめられて日本語として表示できていることが確認できます。

このまま折れ線グラフでも確認できますが、値とトークンの対応が見えづらいため、テキストを並べてその背景にlogprobをカラースケールで表示することにします。

from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
import matplotlib.cm as cm

norm = Normalize(vmin=min(char_logprobs), vmax=0, clip=True)

# カラーマッパーを生成
mapper = cm.ScalarMappable(norm=norm, cmap=cm.Reds)

# トークンを10個ごとに改行するための設定
tokens_per_row = 50
num_rows = (len(chars) + tokens_per_row - 1) // tokens_per_row

# 描画するためのfigとaxを取得
fig, ax = plt.subplots(figsize=(20, 2)) 
ax.axis('off')

# テキストボックスのマージンを設定
margin_x = 0.1 / tokens_per_row  # X軸のマージン
margin_y = 0.1 / num_rows  # Y軸のマージン

# トークン毎にテキストボックスを描画
for i, (word, logp) in enumerate(zip(chars, char_logprobs)):
    # 正規化された値に基づいて色を取得(透明度も含む)
    color = mapper.to_rgba(logp)
    alpha = 1 - norm(logp)  # 透明度の調整
    color = color[:3] + (alpha,)  # RGBAフォーマットに修正

    # テキストボックスの位置を決定
    row = i // tokens_per_row
    col = i % tokens_per_row
    text_x = col * (1.0 / tokens_per_row) + margin_x / 2
    text_y = 1 - (row + 1) * (1.0 / num_rows) + margin_y / 2

    # テキストボックスのサイズを決定
    box_width = 1.0 / tokens_per_row - margin_x
    box_height = 1.0 / num_rows - margin_y

    # テキストボックスを描画
    ax.add_patch(Rectangle((text_x, text_y), box_width, box_height, color=color, transform=ax.transAxes))
    ax.text(text_x + box_width / 2, text_y + box_height / 2, word, ha='center', va='center', color='black',
            transform=ax.transAxes)

logprobの値が小さいトークンの背景を濃くすることでどのあたりで確信度が低くなっているか分かりやすくなりました。今回の例では、出力全体としては正しかったものの、トークンレベルではその確信度に濃淡があることが確認できます。

終わりに

今回の記事では、ChatGPTのLog Probabilityを用いてトークン単位での出力確信度を可視化する方法を紹介しました。

出力の傾向として固有名詞(の先頭トークン)や数値の確信度が低くなることが多かったのですが、同じ入力でも結果が変わったりと不確実なことも多く、例えばRAG(Retrieval-Augmented Generation)のGeneratorの出力確信度をユーザーに表示して情報を取捨選択してもらうなど、特定の用途に使うにはもう少し分析と検証が必要だと感じました。例えば、最近seedパラメーターが追加されてこの値を固定することで決定的ではないですが再現性を高めることができるようになったので[5]、そちらも試してみたいと思います。

AI ShiftではこのようにLLMを組み込んだプロダクトの改善のための技術検証を行なっています。この分野でチャレンジしたい、興味がある方はぜひお気軽にお声がけください!

ここまで読んでいただきありがとうございました。

参考

[1] https://aclanthology.org/2023.emnlp-main.155/
[2] https://aclanthology.org/2023.acl-long.910/
[3] https://aclanthology.org/2023.eacl-main.75/
[4] https://platform.openai.com/docs/api-reference/chat
[5] https://platform.openai.com/docs/guides/text-generation/reproducible-outputs

PICK UP

TAG