はじめに
こんにちは、AIチームの杉山です。
今回の記事では、QA検索などを行う際の文類似度の計算に、文の埋め込みベクトルを用いてknnで計算した場合とSVMを用いた場合の結果を簡易的ですが定量的に比較してみたいと思います。
動機としては、LangChainのRetrieverの実装にkNNだけでなくSVMを用いた実装が採用されており、その説明の中で以下のようにSVMの方が良いことが多いとされていたことでどれくらい性能が異なるかを確認したいと思ったためです。[1][2]
TLDR in my experience it ~always works better to use an SVM instead of kNN, if you can afford the slight computational hit
具体的には、クエリ自身だけを1クラス、検索対象のドキュメント集合全体を別の1クラスとしてSVMで2クラス分類モデルを学習し、そのdecision_functionで得られるスコアを検索スコアとして使用します。クエリに対して毎回SVMのモデルを学習する必要があるのが難点ですが、これだけで検索精度が向上するのであれば魅力的な手法と考えられます。
検証方針
QA検索の場合、入力に対して事前に用意しているQuestion集合の中から類似したものを選出します。
実際のケースを想定すると、ユーザーからの入力クエリに対応するQuestionが一意に定まるとは限らず、検索スコアの上位N件をサジェストすることが一般的です。
類似度計算手法の評価にはユーザーの入力クエリq_nに対してQuestion集合Q_Mにそれぞれ類似度スコアが設定されていて、MRRなどで比較できると理想的ですがそういった都合の良い日本語のデータは見つけられません。
そこで今回は、その設定を簡略化して日本語意味的類似度計算タスクデータセットのJSTS[3]を用いて、JSTSで設定されている2文間の類似度スコアとkNN/SVMによる類似度スコアの相関係数の比較を行います。
データセット
検証には先述の通りJSTSを用います。データセットはHuggingFaceより以下のように取得します。
なお、件数は実際のQAプロダクトでのQuestion集合の規模の想定や、後段の文埋め込みやSVMの学習のコストを考慮して100, 500, 1000件で実験します。
from datasets import load_dataset
import pandas as pd
dataset = load_dataset("shunk031/JGLUE", 'JSTS', split='train[:N]') # N=100, 500, 1000
df = pd.DataFrame(dataset)
取得したデータはIDを除くと以下のような構造となっており、2文間の類似度がlabelとして0~5の値で付与されています。(値が大きいほど類似している)
sentence1 | sentence2 | label |
川べりでサーフボードを持った人たちがいます。 | トイレの壁に黒いタオルがかけられています。 | 0.0 |
二人の男性がジャンボジェット機を見ています。 | 2人の男性が、白い飛行機を眺めています。 | 3.8 |
今回の検証ではsentence1をクエリq_n、sentence2をQuestion集合Q_Mと見做して、sentence1の各q_nに対して対応するsentence2(Q_n)とのkNN/SVMでの類似度計算結果とlabelの相関係数を比較します。
文埋め込みの獲得
文埋め込みの獲得には、OpenAIの埋め込みモデルから text-embedding-ada-002 を用います。
import openai
openai.api_key = "YOUR_OPENAI_KEY"
def get_ada_embedding(text: str) -> List:
response = openai.Embedding.create(
model='text-embedding-ada-002',
input=text
)
return response.data[0].embedding
df['emb1'] = df.sentence1.map(get_ada_embedding)
df['emb2'] = df.sentence2.map(get_ada_embedding)
類似度計算
獲得した埋め込みベクトルに対し、kNNとSVMでの類似度計算を行います。それぞれの実装は今回の記事の元になった[4]に倣い、以下に疑似コードを示します。
kNN
import numpy as np
from numpy import dot
for q_n in q_N:
similarities = Q_M.emb.dot(q_n.dot)
sorted_ix = np.argsort(-similarities)
# JSTSのlabelと比較するためにq_nに対応するQ_nの類似度を取得.
similarities[Q_n]
SVM
from sklearn import svm
for q_n in q_N:
x = [q_n.emb, [Q_M.emb]] # クエリとQuestion集合それぞれの埋め込み
y = np.zeros(len(x))
y[0] = 1 # クエリのみラベルを1に設定
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y)
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
# JSTSのlabelと比較するためにq_nに対応するQ_nの類似度を取得.
similarities[Q_n]
結果
JSTSのlabelとknn/SVMの類似度との相関係数は、各データ数Nに対して以下の値となりました。labelとそれぞれの類似度はスケールが異なりますが、どちらも似ているペアには大きい値が付与されることから正の相関が強いほど今回算出した類似度がlabelの類似度と似た傾向であると考えられます。
label - knn_sim | label - svm_sim | (参考)knn_sim - svm_sim | |
N=100 | 0.85285017 | 0.86324322 | 0.98992115 |
N=500 | 0.80645642 | 0.79177172 | 0.98248548 |
N=1000 | 0.8066241 | 0.78108655 | 0.97888241 |
結果として、N=100では若干SVMの方が高かったものの、Nを大きくするとkNNの方が高い値を示すという結果となりました。また、N=100の際も全体から選出する100件の取り方を変えると結果が逆転したりと全体的にSVMの手法は今ひとつな結果となりました。実際のQA検索としてのユースケースを考えると、クエリに対して毎回SVMのモデルを学習する必要もあり、大きなメリットは感じられませんでした。
おわりに
元になったNotebookでは、SVMがなぜうまくいくのか、について以下のような説明されていました。
In simple terms, because SVM considers the entire cloud of data as it optimizes for the hyperplane that "pulls apart" your positives from negatives. In comparison, the kNN approach doesn't consider the global manifold structure of your entire dataset and "values" every dimension equally. The SVM basically finds the way that your positive example is unique in the dataset, and then only considers its unique qualities when ranking all the other examples.
一方で、社内で話していた時にはLinearSVCでクラス分類しているので実質的にはembedding同士の内積が大きいものを選ぶことになり、内積(絶対値が全部同じならコサイン類似度)基準でkNNを行うのとあまり変わらないのでは、という考えもいただいており、kNNとSVMの類似度の相関係数がかなり高いことからもその可能性が高いのではと考えられます。
今回の検証では文埋め込みをBERTなど別の手法で試したり、RBFカーネルなどの線型SVM以外を用いたりといったことは行いませんでしたが、そのあたりの試行錯誤や日本語以外でのデータでの検証は行なってみたいと思います。
ここまで読んでいただきありがとうございました。
参考
[1] https://secon.dev/entry/2023/04/29/220000-langchain-svm-retriver/
[2] https://python.langchain.com/docs/modules/data_connection/retrievers/integrations/svm
[3] https://www.anlp.jp/proceedings/annual_meeting/2022/pdf_dir/E8-4.pdf
[4]https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb