Tracr-Injection: Distilling Algorithms into Pre-trained Language Models

生成日:

Tracr-Injection: Distilling Algorithms into Pre-trained Language Models

Abstract(日本語訳)

大規模言語モデルの急速な普及を背景に、Transformer アーキテクチャが本来備える記号的能力を形式的に特徴づけようとする動きが進んでいる。RASP と呼ばれるプログラミング言語は、こうしたアルゴリズムを実装する Transformer の重みに直接コンパイルできるものとして提案されている。しかし、RASP で実装できるタスクは、自然な教師なしデータから学習するにはまれな場合が多く、Transformer アーキテクチャの理論上の能力と、その能力を教師なしデータから実際に学習できるかどうかとの間にはずれがある。本論文では、RASP で書かれたアルゴリズムを事前学習済み言語モデルへ直接蒸留できる方法である tracr-injection を提案する。この方法の例として、3 種類のアルゴリズムを言語モデルに注入する。提案手法は、モデルの residual stream 内に解釈可能な部分空間を作り、この部分空間を RASP アルゴリズムのコード中に現れる変数へデコードできることを示す。さらに、提案手法はベースラインと比べて out-of-distribution 性能を改善できることが分かり、モデル内部でより記号的な機構が働いていることを示唆している。実験に用いたコードは公開されている。

論文の面白いところ

この論文の焦点は、言語モデルに記号処理を「外から使わせる」のではなく、パラメータの中へ入れようとする点にある。既存研究では、RASP でコンパイルしたモデルに処理を委ねる方法があるが、それでは外部コードを実行する場合との差が曖昧になりやすい。本論文は、コンパイル済み Transformer の residual stream を GPT-2-large の residual stream に対応させることで、既知のアルゴリズムをモデル内部の表現として持たせようとする。ここで重要なのは、単に正答率を上げるだけではない。RASP プログラムには、括弧の数、桁上がり、位置などの中間変数があるため、蒸留後のモデルからそれらに対応する表現を読み出せる。通常のファインチューニングでは、モデルがどの手掛かりで答えたかは後から推測するしかない。本手法では、少なくとも実験した範囲では、注入した部分空間にノイズを加えると性能がより大きく落ちるため、その表現が予測に因果的に関与していると考えられる。解釈可能性研究で不足しがちな「正解の回路」に近い参照点を人工的に作れるところが、この研究のいちばん使い道のある部分である。

問題設定

Transformer は理論上、一定の記号的アルゴリズムを表現できることが知られている。RASP はその性質を扱うための言語であり、tracr は RASP プログラムを実際の Transformer 重みにコンパイルするライブラリである。問題は、表現できることと、通常の事前学習データから学べることが同じではない点にある。たとえば、括弧列が正しく釣り合っているかを判定する Dyck 系の言語や、桁上がりを伴う加算は、自然文の次トークン予測だけから安定して学ばれるとは限らない。単純なファインチューニングで正答率を上げることはできても、モデルが本当にアルゴリズムを実装したのか、訓練分布の表面的な規則を覚えたのかは判別しにくい。さらに、機械論的解釈可能性の研究では、抽出された circuit が本当に正しい機構かを評価する基準が不足している。本論文は、既知のアルゴリズムを事前学習済みモデルに注入し、その中間変数を読めるようにすることで、この二つの問題を同時に扱う。対象は実用タスクそのものではなく、記号処理を持つモデルを作り、その内部を検査するための実験的な枠組みである。

提案手法

提案手法 tracr-injection は、まず目的のアルゴリズムを RASP で書き、tracr により causal Transformer としてコンパイルする。このコンパイル済みモデルは、対象タスクを正しく解く小さな教師モデルとして働く。次に、GPT-2-large をタスクデータでファインチューニングしながら、各層の residual stream を線形写像でコンパイル済みモデルの residual stream に近づける。損失は三つから成る。第一はタスクの正解を学ぶ cross-entropy loss、第二は両モデルの内部表現を cosine loss で合わせる algorithm loss、第三は FineWeb の教師なしテキストに対する KL divergence loss である。KL loss は、アルゴリズムを入れる過程で事前学習済みモデルの一般的な言語能力が大きく崩れないようにするために置かれている。さらに著者らは、注入した記号的部分空間をモデルが実際に使うように、最終 residual stream の直交成分をバッチ内の別例から入れ替える介入を加える。これにより、モデルはタスク解決に必要な情報を、RASP 由来の線形部分空間に載せざるを得なくなる。実験対象のアルゴリズムは、Shuffle-Dyck の判定、文字列中の x の数のカウント、二つの整数の加算である。

結果

主実験では、GPT-2-large に三つのアルゴリズムをそれぞれ注入し、通常のファインチューニングに KL loss を加えたベースラインと比較している。in-distribution のタスク精度は、Shuffle-Dyck で 99.9%、整数加算で 95.0%、カウントで 99.2% であった。ベースラインはそれぞれ 100.0%、99.7%、99.8% であり、単純な正答率だけを見るとベースラインの方がわずかに高い。Tiny-Shakespeare 上の perplexity や SST-2、MRPC、LAMBADA の結果を見る限り、提案手法による大きな catastrophic forgetting は観察されていない。重要なのは、out-of-distribution 評価での差である。Shuffle-Dyck では、ほぼ釣り合った誤例、長さを 3 倍にした例、未知の括弧種を加えた例のいずれでも、tracr-injection がベースラインより高い精度を示した。一方、整数加算やカウントでは結果が一貫せず、4 桁加算や長いカウントでは両手法とも弱かった。つまり、提案手法は常に完全なアルゴリズム注入を実現するわけではないが、RASP で与えた構造が汎化に効く場合がある。著者ら自身も、より複雑なタスクやより大きなモデルへの拡張は今後の課題としている。

具体例

入力が Are parentheses here correctly matched? [(]). Answer: のような括弧判定だとする。通常の言語モデルなら、訓練で見た括弧列に似ているか、あるいは局所的なパターンに頼って Yes / No を出してしまう可能性がある。tracr-injection では、まず RASP で書かれた Shuffle-Dyck 判定の手続きが、開き括弧と閉じ括弧の個数や位置に関する中間変数を作る。tracr でコンパイルされた教師モデルは、その変数を residual stream の特定の次元に持つ。GPT-2-large の学習時には、同じ入力に対して GPT-2 側の residual stream も線形写像を通じてその変数表現に近づけられる。期待される出力は、括弧種ごとの数が釣り合っていれば Yes、釣り合っていなければ No である。間違えやすいのは、通常の Dyck 言語では順序も重要だが、Shuffle-Dyck では括弧種ごとのバランスが主な条件になる点である。たとえば [(]) は通常の括弧対応としては不自然に見えるが、Shuffle-Dyck では []() の数が合っていれば正例として扱われる。このような規則を、出力ラベルだけでなく中間変数としてモデル内部に持たせることが、本論文の中心的な考えである。