SEAL: Scaling to Emphasize Attention for Long-Context Retrieval

生成日:

SEAL: Scaling to Emphasize Attention for Long-Context Retrieval

論文の面白いところ

この論文の特色は、長文処理の失敗を単に「文脈窓が足りない」と見ない点にある。対象のモデルは、形式上は 16K、32K、あるいは 128K トークンを扱える。それでも、長い入力の中から特定の数値や文を取り出すだけの課題で、文脈長が伸びるにつれて誤答が増える。著者らは、この失敗がモデルの知識不足ではなく、注意機構の偏りに関わると見る。実験では、注意ヘッドを一つずつ取り除くと、精度が下がるものだけでなく、逆に精度を上げるものもあることが示される。この観察は素朴だが、モデル内部に長文検索を助ける成分と妨げる成分が混在することを示す。SEAL はその成分を全面的に再学習するのではなく、強めるか弱めるかを学ぶ。しかも必要なデータは、実世界の大規模コーパスではなく、課題の形式に合わせた 50 件程度の合成例でよい。長文モデルを作り直すよりも、既存モデルの注意の利き方を整えるという、実務上扱いやすい方向を示している。

問題設定

長文対応の LLM は、長い文書読解、コード生成、複数ターンの対話などで用いられる。しかし、指定された文脈窓の内側であっても、入力が長くなるほど出力品質が落ちることがある。論文が主に扱うのは、長い入力の中に明示的に書かれた情報を取り出す長文検索である。たとえば LongEval の line retrieval では、多数の行にキーと数値が並び、質問で指定されたキーに対応する数値を返す。Needle-in-a-Haystack では、長い文章の途中に短い事実文を挿入し、その内容を質問する。RULER では、変数追跡、共通語の抽出、高頻度語の抽出など、規則に従って情報を拾う課題を扱う。これらは、モデルが外部知識を思い出す課題ではない。答えは入力中にあり、必要なのはその位置に届く検索能力である。したがって、誤答が生じるなら、長い文脈内で注意が適切に働いていない可能性がある。論文は、この検索能力を既存モデルに少ない調整で補うことを目的とする。

提案手法

SEAL は、注意出力にかけるスケールを学習し、長文検索に役立つ注意成分を相対的に強める方法である。著者らはまず、LongChat-7B-v1.5-32K を用いて、注意ヘッドを一つずつ剪定し、line retrieval の精度変化を調べた。あるヘッドを消すと精度が落ち、別のヘッドを消すと精度が上がるという結果が得られた。さらに、ヘッド単位だけでなく、ヘッド内部のチャンネル単位でも影響の差が見られた。これに基づき、SEAL-H は各注意ヘッドに一つの学習可能なスカラーを置き、SEAL-C は注意出力のチャンネルごとにスケールを置く。学習には下流課題と同じ形式の合成データを使う。line retrieval ならランダムな行名と数値を並べ、Needle-in-a-Haystack ならランダムな事実文と質問を作る。重要なのは、合成データの意味内容ではなく、入力と出力の形式である。学習後のスケールは、Llama 系の v_proj や o_proj の重みに事前に掛け込める。そのため、推論時には追加のモジュールや計算を持ち込まなくてよい。

結果

line retrieval では、LongChat-7B-v1.5-32K の 31K トークン入力における精度が、ベースラインの 0.32 から SEAL-H で 0.80、SEAL-C で 0.88 へ上がった。Vicuna-13B-v1.5-16K でも、長い入力で 0.42 だった精度が両方式で 0.94 まで上がった。Mistral-7B-Instruct-v0.2 はもともと大きな低下が少ないが、SEAL 適用後は多くの長さでほぼ満点に近い値を示した。Needle-in-a-Haystack でも、50 件の合成例だけで、挿入位置や入力長にまたがる検索精度が改善した。RULER では、Llama-3.1-8B-Instruct、Mistral-7B-Instruct-v0.2、LongChat-7B-v1.5-32K を対象に、変数追跡、共通語抽出、高頻度語抽出を評価している。とくに Llama-3.1-8B-Instruct の common word extraction では、64K 入力でベースラインが 0.1 まで落ちる一方、SEAL-C は 95.7 を保った。LongBench の文書質問応答でも、平均値として改善が報告されている。SEAL-H は LongChat-7B 全体で 1,024 個の学習パラメータだけを使い、LoRA よりはるかに小さい学習空間で同程度の line retrieval 性能を得た。MMLU の値はほぼ変わらず、一般的な知識問題の能力を大きく崩していないことも示されている。制約として、より長い訓練例を用いる場合は GPU メモリが増え、学習率の調整も多少必要になる。

具体例

たとえば、入力に数万トークンのログ風テキストがあり、その中に line verdant-efficiency: REGISTER_CONTENT is <24819> という行が一度だけ現れるとする。質問は「line verdant-efficiency の REGISTER_CONTENT は何か」である。モデルは長い入力全体を読んだうえで、指定された行名を探し、その直後にある数値を答えなければならない。期待される出力は「verdant-efficiency の REGISTER_CONTENT は 24819 である」という形になる。短い入力では、多くのモデルがこの処理を問題なく行える。しかし入力が 30K トークン前後まで伸びると、似た名前の行や近くにある別の数値に注意が寄り、たとえば 24856 のような誤った数値を返すことがある。SEAL は、このような形式の合成例を少数作り、正しい行と正しい数値を結びつける際に有用な注意成分を強める。反対に、誤った候補を押し上げやすい成分は弱まる。処理の内容は新しい知識を覚えることではなく、長い列の中で指定された場所へ注意を届かせることである。したがって、この手法は、契約書、ログ、長い会話履歴のように、答えが文中に明示されているが位置が遠い場面に向く。