Implicit Reasoning in Transformers is Reasoning through Shortcuts
- Transformer の暗黙的推論を、多段の加減算課題で調べた論文である。
- 前提の順序が固定されていれば、GPT-2 型モデルは中間結果を内部で順に受け渡し、高い精度で解ける。
- 前提の順序が乱れると、モデルは変数を追跡するよりも数を連結する近道に寄り、変数が減数になる例で大きく失敗する。
論文の面白いところ
この論文は、言語モデルが「考えているように見える」場合に、内部で何をしているのかをかなり小さな実験系に落として調べている。対象は自然文の難問ではなく、複数段の加算と減算である。こうすることで、事前学習中に見た事実の記憶や言い換えの影響を避け、推論の型そのものを観察しやすくしている。著者らの主張は単純で、暗黙的推論はうまく見える場面でも、しばしば近道によって成り立つというものである。興味深いのは、モデルがまったく逐次計算できないわけではない点である。前提が計算順に並ぶ固定パターンでは、中間結果が層と位置をまたいで斜めに伝わるような挙動が観察される。つまり、内部に逐次処理らしい構造は現れうる。しかし前提の順序が変わると、その構造は安定して使われず、加算の交換法則に依存した数の連結に近い処理へ傾く。暗黙的な Chain-of-Thought を期待する研究に対して、どの条件なら内部推論と呼べるのかを問い直す材料を与えている。
問題設定
研究の中心は、明示的な途中式を出さない言語モデルが、多段推論を内部だけで実行できるかという問いである。明示的推論では、モデルは途中の計算や根拠をトークンとして出力する。暗黙的推論では、それらを出力せず、隠れ状態の中で処理したうえで答えだけを返す。暗黙的推論は生成トークン数が少なく、推論時の費用が小さいという利点をもつ。一方で、複雑な推論では明示的推論より弱いことが多い。本論文は、その弱さが単なる計算能力の不足なのか、あるいは学習された処理方略の問題なのかを調べる。課題には、剰余 23 の範囲で行う合成的な加減算系列を使う。各ステップは、前の変数と数値と演算子から次の変数を定める形をとる。訓練は 1 から 5 ステップ、評価は同じ長さの分布内評価と、6 または 7 ステップの分布外評価を含む。さらに、前提を計算順に並べる場合と、逆順またはランダム順にする場合を分けている。
提案手法
著者らは、12 層の GPT-2 を基礎とし、位置埋め込みを Rotary Position Embedding(RoPE)に置き換えたモデルを用いる。RoPE を使うのは、訓練時より長い系列へ能力が伸びるかを見やすくするためである。訓練データは、加算と減算からなる多段の式テンプレートを作り、変数名を変えて具体化することで生成される。テストでは、訓練データと中間計算が重なるテンプレートを除き、単なる暗記では解けないようにしている。解析には activation patching を用いる。これは、元の入力と一部を変えた入力を別々に通し、特定の層や位置の活性だけを差し替えて出力への影響を見る方法である。差し替えで正解トークンのロジットが大きく動く場所は、その計算に重要だったと解釈できる。著者らは、残差ストリーム、Attention、MLP を分けて調べる。さらに、注意範囲を現在のステップだけに制限した場合と、直前のステップまで見られる場合を比べ、中間結果を次段へ渡しているかを確認している。
結果
前提が計算順に固定されている場合、モデルは暗黙的な多段計算をかなりよく学習した。分布内の課題では精度が 100% に達し、訓練より 1 ステップ長い課題でも 99%、2 ステップ長い課題でも約 90% の精度を示した。activation patching では、各ステップの終端付近から次のステップへ、重要な情報が斜めに伝わるような分布が見られた。これは、各段の中間結果を内部で作り、次の段で再利用していることを示す。注意範囲を現在のステップに限ると推論能力は失われ、直前のステップまで見られるようにすると精度が回復した。Attention は中間結果の伝達に、MLP は入力や出力に関する特徴の強調に関わると解釈される。
しかし、前提の順序が固定されていない訓練では様子が変わる。5 ステップ課題の精度はおよそ 40% 程度まで落ちる。特に、変数が減算の右側、すなわち減数になる式が増えると精度は急激に低下する。著者らはこれを “Variable as Subtrahend Plight” と呼ぶ。たとえば b = a - 3 なら、数を順につなげて ... - 3 と扱う近道が働きやすい。ところが b = 3 - a では、変数の値を先に正しく計算してから引かなければならない。GPT-4o、Claude 3.5 Sonnet、Llama 3 70B Instruct、Qwen2.5 72B Instruct でも、3 ステップ課題で同様の低下が観察された。GPT-4o は、変数が減数にならない場合は高精度だが、該当する式が二つ含まれると約 30% まで下がった。
具体例
入力として、m = 16 - 5, z = 11 - m, b = z + 22, b = ? という三段の式を考える。正しく解くには、まず m = 11 を得て、次に z = 11 - 11 = 0 とし、最後に b = 0 + 22 = 22 と求める。ここで重要なのは、二段目の m が引かれる側ではなく、引く側に置かれていることである。近道を使うモデルは、式全体を 16 - 5 + 11 + 22 のように、数を見つけた順に符号付きで連結して扱いやすい。この近道は、変数が z = m - 11 のように現れる例では偶然うまくいくことがある。しかし z = 11 - m では、m の値を先に確定しなければ符号が決まらない。期待される出力は 22 であるが、近道に頼ると 44 のような誤答が出る。この例は、モデルが式の表面上の並びを使って答える場合と、変数の値を段階的に追跡する場合の差をよく示している。