TESS 2: A Large-Scale Generalist Diffusion Language Model
- TESS 2 は、自己回帰(AR)型ではなく diffusion によって文を生成する、汎用の instruction-following language model である。
- 著者らは Mistral 7B を diffusion 形式へ適応させ、さらに instruction tuning を行うことで、既存の diffusion language model を大きく上回る性能を示した。
- SQuAD や TriviaQA などの QA では AR モデルに近い水準に達する一方、GSM8k や BBH のような推論中心の課題では差が残る。
- inference-time compute を diffusion step 数で調節でき、reward model の勾配で出力を誘導できる点が、AR 型とは異なる実用上の特徴である。
Abstract(日本語訳)
本論文では TESS 2 を導入する。TESS 2 は、一般的な instruction-following を行う diffusion language model であり、同時期の instruction-tuned diffusion model を上回り、強力な自己回帰(AR)モデルに匹敵し、ときにはそれを上回る。TESS 2 は、まず強力な AR モデルを diffusion loss による継続事前学習で適応させ、その後さらに instruction tuning を行って訓練する。良い instruction-following diffusion model を訓練するには、適応訓練と base model の選択が重要であることが分かった。さらに本論文では reward guidance を提案する。これは、基礎となるモデルを訓練し直すことなく、inference time に出力をアラインメントするための新しいモジュール型の guidance 手法である。最後に、TESS 2 は inference-time compute を増やすことでさらに改善することを示し、diffusion language model が inference time に用いる計算量を細かく制御できる有用性を明らかにする。
論文の面白いところ
現在の大規模言語モデルの多くは、左から右へ次のトークンを順に予測する AR 型である。TESS 2 はその常識から少し離れ、文全体のノイズを段階的に取り除く diffusion の枠組みで instruction-following を扱う。画像生成では自然な考え方である diffusion を、長いテキスト生成にも実用的な規模で持ち込もうとする点に、この論文の主眼がある。単に小さな実験モデルを作ったのではなく、Mistral 7B を土台にして 45B トークン規模の適応訓練を行い、Tulu 系の instruction data で仕上げている。結果として、従来の diffusion language model が苦手としていた通常の QA や instruction-following で、かなり読める出力を出すところまで来ている。ただし、推論能力が AR 型を全面的に置き換えるという話ではない。むしろ、QA や長めの回答生成では近づくが、数学や複雑な reasoning ではなお弱い、という境界を丁寧に示している。もう一つ興味深いのは、reward guidance を訓練後の推論だけで差し込める点である。これは RLHF のように本体を再訓練するのではなく、生成途中の状態を reward model の勾配で少しずつ望ましい方向へ動かす方法である。diffusion 型ならではの制御性を、言語モデルのアラインメントや生成コストの調整に使える可能性を示している。
問題設定
AR 型の言語モデルは、各時点で次のトークンを一つずつ決めるため、実装と訓練が分かりやすく、現在の標準になっている。一方で、長い出力ではトークン数に応じて forward pass が増え、途中で決めた語を後から全体として直すことも難しい。近年の reasoning model は test-time compute を増やして性能を上げるが、多くの場合は長い chain-of-thought を生成するため、生成コストも大きくなる。diffusion language model は、文全体をノイズのある状態から段階的に復元するので、生成過程を別の形で制御できる。たとえば、step 数を増やして精度を上げる、少なくして速く返す、といった調整が自然にできる。しかし従来の diffusion language model は、規模が小さいか、perplexity のような内部指標に重点が置かれ、実際の instruction-following benchmark で AR モデルと比べられる段階には十分達していなかった。この論文は、既存の強い AR モデルを diffusion model に作り替えれば、汎用的な指示応答にも使えるのかを問う。さらに、その変換で何が効くのか、base model の選択、masking、attention、label shifting、instruction tuning を分けて調べる。問題は単なる生成方式の置換ではなく、AR で得た知識を diffusion の生成過程へどれだけ保てるかにある。
提案手法
TESS 2 は simplex diffusion language model を基礎にしている。トークンを離散 ID のまま扱うのではなく、語彙上の simplex 表現に写し、そこへノイズを加え、モデルに元のトークンを復元させる。訓練時の損失には、通常の diffusion でよく使われる mean squared error ではなく、トークン予測の cross-entropy を用いる。著者らはまず AR モデルを diffusion 用に適応させるため、UL2 masking を使って span infilling と prefix completion の両方を学習させる。span infilling は文中の一部を埋める訓練であり、prefix completion は与えられた前半から後半を生成する訓練である。次に label shifting を入れ、AR モデルが持っていた次トークン予測の性質に近い形で学習を進める。attention については causal mask を外し、出力列全体に対する双方向 attention を使う。これにより、diffusion model が文全体の情報を使って復元する性質を保つ。適応後は Tulu 2 SFT mixture などで instruction tuning を行い、ユーザー指示に対する回答を生成できるようにする。さらに reward guidance では、各 diffusion step の途中出力を reward model に通し、その reward が高くなる方向へ logits を勾配で更新する。この処理は訓練ではなく推論時に行われるため、別の reward model を差し替えることもできる。
結果
base model の比較では、Mistral v0.1 が RoBERTa、Llama 2、Llama 3 より適応しやすい結果になった。特に Llama 系は短い適応訓練では coherent な文を出しにくく、双方向 attention への切り替えが内部表現を大きく変えた可能性があると著者らは述べている。Mistral を 2048 トークン文脈で長く適応させると、perplexity は下がり、Mauve などの生成品質指標も良好に保たれた。instruction tuning 後の評価では、TESS 2 v0.3 が diffusion model の比較対象を全体に上回った。AlpacaEval、SQuAD、TriviaQA では、AR 版 Mistral に近い、または一部で上回る値を示した。たとえば TESS 2 v0.3 は TriviaQA で 53.8 を記録し、表中の Mistral v0.3 AR の 36.7 を上回っている。一方、BBH と GSM8k では差が残り、TESS 2 v0.3 の BBH は 10.8、GSM8k は 36.5 で、推論中心の課題では AR 型の強さが残る。GSM8k symbolic dataset で追加ファインチューニングした場合には、TESS 2 が AR 版 Mistral を上回る結果も得られた。これは、diffusion model でも十分な領域特化データがあれば性能を伸ばせることを示している。reward guidance は AlpacaEval で約 3〜4 ポイントの改善をもたらしたが、guidance weight を上げすぎると意味のない出力に寄る。inference step 数を増やす実験では、GSM8k は 1000 step まで改善傾向を示し、AlpacaEval は 500 step を超えると反復が増えて評価が落ちた。速度面では、素朴な Hugging Face 実装でも、最大 2048 トークンの設定で diffusion model が AR 版より短い時間で batch を処理したと報告している。
具体例
たとえば、ユーザーが「Mr Sinister とは誰か」と短く尋ねたとする。AR 型なら、まず「He」を出し、次に「is」を出し、その後も一語ずつ前の出力に基づいて文を伸ばしていく。TESS 2 では、最初は回答全体がノイズを含む未確定の列として置かれ、diffusion step を重ねるごとに「He is an X-men villain.」のような文へ近づく。途中段階では語が崩れていたり、同じ語が反復したりするが、後半の step で全体が整えられる。reward guidance を使う場合、各 step の候補回答を reward model が評価し、より有用な回答に見える方向へ生成状態を少し動かす。この例では、単に名前だけを返すより、「X-Men の悪役である」と短く説明する回答が期待される。間違えやすい点は、diffusion が文全体を同時に扱うからといって、事実関係や推論が自動的に正しくなるわけではないことである。論文中の評価でも、TESS 2 は QA では比較的よいが、数学や多段推論では基本的な計算や論理を誤ることがある。したがって現時点での利用場面は、厳密な推論よりも、回答生成、検索結果に基づく QA、長い文の並列的な生成などに近い。step 数や reward guidance を調整できることは利点だが、過度に強く誘導すると reward hacking に似た崩れた出力が起きる。