LongReD: Mitigating Short-Text Degradation of Long-Context Large Language Models via Restoration Distillation

生成日:

LongReD: Mitigating Short-Text Degradation of Long-Context Large Language Models via Restoration Distillation

論文の面白いところ

この論文の中心は、文脈窓を広げることを単なる長文処理の問題として見ない点にある。多くの実装では、Rotary Position Embedding(RoPE)の設定を変え、長い系列で軽く継続事前学習を行う。これにより 32K、128K、さらに長い文脈を扱えるようになるが、MMLU や HumanEval のような短いタスクで性能が下がる場合がある。著者らは、この現象を経験則として片づけず、モデル内部の分布の変化として測定する。元のモデルと拡張後のモデルを同じ短文に適用し、隠れ状態のコサイン類似度と注意分布の Kullback-Leibler(KL)ダイバージェンスを見る。継続事前学習によって差は小さくなるが、完全には戻らない。さらに、学習の初期には短文性能が回復しても、学習を続けると再び落ちることを示す。このため、問題は位置埋め込みの変更だけでなく、長文データに適応する過程で元の能力を忘れることにもある。実用上は、長文対応モデルを作るときに、短い問い合わせやコード補完が犠牲になっていないかを確認する必要がある、という素朴だが重要な示唆を与えている。

問題設定

対象は、大規模言語モデル(Large Language Model, LLM)の文脈窓拡張である。一般的なデコーダ型 Transformer は、事前学習時の長さや位置符号化に強く制約される。入力が想定より長くなると、位置情報が学習時と異なる領域に入り、性能が不安定になる。そこで、RoPE の基底を変えたり、位置を補間したりして、より長い系列を処理できるようにする手法が使われる。だが、長い系列に合わせて継続事前学習を行うと、短い入力に対する成績が落ちることがある。本論文は、この短文性能の低下を説明し、それを抑えながら文脈窓を拡張する方法を考える。評価では Llama-3-8B と Mistral-7B-v0.3 を用い、短文側では一般知識、コード、数学、読解、常識質問応答を含む 17 種のベンチマークを見る。長文側では RULER を用い、長い文脈の中から必要な情報を使えるかを調べる。したがって、目標は長文だけで高い点を取ることではなく、元の短文能力をなるべく保ったまま文脈窓を広げることである。

提案手法

著者らの方法は Long Context Pre-training with Restoration Distillation(LongReD)と呼ばれる。LongReD は、一つの学習で三つの目的を合わせる。第一は通常の長文学習であり、拡張後のモデルに長い系列を読ませ、次トークン予測で長距離依存を学ばせる。第二は短文蒸留であり、元のモデルを教師、拡張後のモデルを生徒として、短い入力に対する選択された層の隠れ状態を近づける。蒸留する層は一律には選ばず、元モデルと拡張モデルの注意分布の KL ダイバージェンスが大きい層を優先し、最後の層も含める。すべての層を縛ると長文への適応を妨げるため、ずれの大きい少数の層に絞る設計である。第三は short-to-long distillation である。ここでは短いテキストを使いながら、途中の位置番号を飛ばして、長い文脈中の離れた位置に置かれたかのように扱う。元のモデルは通常の位置番号で処理し、拡張後のモデルは飛ばした位置番号で処理し、最後の層の出力を近づける。この処理により、短文で保った能力を長い位置範囲にも移すことを狙う。最終的な損失は、長文の言語モデル損失、短文蒸留損失、short-to-long 蒸留損失の和で構成される。

結果

分析では、文脈窓を広げたモデルと元のモデルの内部表現にずれが残ることが示された。Llama-3-8B を 32K や 128K に拡張した場合、継続事前学習後も隠れ状態の類似度は完全には 1 に戻らず、注意分布にも差が残る。さらに、隠れ状態の類似度が高い拡張モデルほど、MMLU の性能保持率も高い傾向があった。学習ステップの実験では、短文性能は初期に回復するものの、その後の長文継続学習で低下する。短文データを混ぜるだけでも忘却は緩和されるが、LongReD はそれより一歩進めて、元モデルの内部表現を明示的に保存する。主結果では、Llama-3-8B を 32K に拡張した場合、単純な長文継続事前学習の短文平均は 51.00 であったのに対し、LongReD-C は 54.85 であり、元モデルの 55.16 に近かった。著者らは、この設定で短文性能の保持率を 99.4% と報告している。128K への拡張でも、LongReD は単純な長文学習より短文平均を高く保った。Mistral-7B-v0.3 を 128K に拡張する実験でも、LongReD は短文平均と RULER の双方で通常の継続事前学習を上回った。アブレーションでは、短文蒸留を外すと短文性能が落ち、short-to-long 蒸留を外すと長文性能が落ちる。層選択についても、全層蒸留や最後の層だけの蒸留より、KL ダイバージェンスに基づく少数層の蒸留がよい結果を示した。

具体例

たとえば、ある組織が社内文書検索用に、8K トークンまで扱える Llama-3 系モデルを 128K トークン対応に広げるとする。長い契約書や議事録を一度に読ませたいので、RoPE の設定を変え、長文コーパスで継続事前学習を行う。入力例としては、128K 近い議事録の末尾に「第3四半期の監査担当者は誰か」と質問を置き、文書の前半にある担当者名を答えさせるようなものがある。この長文課題だけを見れば、モデルは以前より長い範囲を読めるようになる。しかし同じモデルに「Python でリストの重複を取り除く関数を書け」や「水は何度で沸騰するか」といった短い入力を与えると、継続事前学習前より答えが不安定になることがある。LongReD では、長い議事録を読む訓練と並行して、短いコード問題や知識問題を元のモデルにも拡張後モデルにも読ませる。拡張後モデルの中間表現が元のモデルから大きく離れている層では、その表現を近づけるように学習する。さらに、短い文章の位置番号を飛ばして、文脈の先頭と末尾に離れて置かれたような状態を作り、長い位置でも元の処理が崩れにくいようにする。期待される出力は、長い議事録では該当箇所から監査担当者名を取り出し、短い Python 問題では簡潔な関数を正しく返すことである。間違えやすい点は、長文を読めるようになったことと、短文の基本能力が保たれていることを同じだと思ってしまうことである。本論文は、その二つを分けて測り、両方を守る訓練目的を置く。