AutoMixer: Checkpoint Artifacts as Automatic Data Mixers
- 言語モデルの事前学習で、訓練途中のチェックポイントをデータ選別器として用いる AutoMixer を提案する。
- 各チェックポイントが得意とする推論ベンチマークを手がかりに、生データをタスク寄りの群へ分け、影響度に基づいてサンプリング比率を定める。
- FineWeb-Edu を用いた実験では、350M の代理モデルを使う設定が一様サンプリングを平均 1.93 ポイント上回った。
論文の面白いところ
この論文は、ふつうは保存物として扱われる訓練途中のチェックポイントを、データ混合のための道具として使う。大規模言語モデルの訓練では、どのデータをどれだけ読ませるかが性能に強く関わる。しかし、ある推論能力を伸ばすデータがどれであるかは、あらかじめ明瞭には分からない。AutoMixer は、この不明瞭さを、訓練の途中で現れる能力の変化から読む。ある時点のチェックポイントが ARC や BoolQ などの課題でよく働くなら、その時点のモデルはその課題に関わるデータの信号を持つとみなす。そこから各サンプルの影響度を近似し、データを再編成する。発想は素朴だが、既存の訓練ログや保存チェックポイントを有効に使う点に実用味がある。最終モデルだけを見てデータを選ぶのではなく、訓練の道中を観察対象にしている点もよい。論文中の結論は過度に一般化できないが、事前学習データの混合を手作業の領域から少し外へ出す試みとして読める。
問題設定
言語モデルの事前学習では、多様なデータを混ぜて、複数の能力を同時に育てる。ところが、タスクに対応するデータ領域はしばしば曖昧である。たとえば、常識推論に効く文書は、単に科学記事や教科書というラベルだけでは決まらない。領域どうしは重なり、同じ文書が複数の能力に関係することもある。反対に、見かけ上は関連していても、訓練上の効果が乏しいサンプルもあり得る。総当たりで多くのデータ混合比を試す方法は、モデルが大きくなるほど費用が合わない。既存の影響関数にも手がかりはあるが、単一の時点だけで見ると、訓練中に移り変わる能力を取り逃がす。論文はこの問題を、データの群分けと、その群へのサンプリング重み付けという二つの課題として定式化する。目的は、限られたトークン予算のもとで、目標タスクに役立つデータをより多く読ませることである。
提案手法
AutoMixer は、まず小さめの代理モデルを訓練し、その途中で保存されたチェックポイントを評価する。評価には ARC-easy、ARC-challenge、BoolQ、PIQA、SIQA、HellaSwag、OpenBookQA、WinoGrande の八つの常識推論ベンチマークを用いる。各タスクで最もよい性能を示したチェックポイントを選び、そのチェックポイントをサンプル影響度の推定に使う。影響度の計算には DataInf に基づく一次近似を用い、ヘッセ行列の明示的な逆行列計算を避ける。計算量を抑えるため、勾配は主に埋め込み層と最終層から取る。次に、各サンプルについて複数チェックポイントからの影響度を集約する。訓練の遅い段階で伸びるタスクほど、チェックポイントのステップ数に由来する係数を通じて重みを持つ。こうして得た値で生データをタスク寄りの群に分け、上位のサンプルを残す。最後に、各群の影響度密度を求め、その比率に従って事前学習時のサンプリング重みを決める。
結果
実験では FineWeb-Edu を事前学習データとして用い、Llama 3 系のデコーダ専用 Transformer を訓練している。比較対象は、一様サンプリング、パープレキシティに基づくサンプリング、n-gram に基づくサンプリングである。AutoMixer には 75M と 350M の代理モデルを用いた二つの設定がある。350M の代理モデルを用いた AutoMixer は、350M の本訓練モデルで一様サンプリングを平均 1.93 ポイント上回った。1.5B では平均 1.22 ポイント、3B では平均 1.05 ポイントの改善であった。個別課題でも、多くの場合に他のサンプリング法より高い値を示している。一方で、75M の代理モデルを用いた AutoMixer はほぼ改善せず、しばしば一様サンプリングを下回った。これは、代理モデルが小さすぎると高影響サンプルの識別が不十分になることを示す。チェックポイント選択の比較では、最終チェックポイントだけを用いる方法や全チェックポイントをまとめる方法より、タスクごとに選んだチェックポイントを使う方法がよい結果を示した。ただし、影響度推定には追加の計算費用があり、異なる領域や大規模な本番訓練にそのまま移せるかは残された課題である。
具体例
常識推論を伸ばすために、ある研究者が FineWeb-Edu から事前学習データを作る場面を考える。入力には、物理現象を説明する短い教材、日常会話の断片、雑多なウェブ記事が混じっている。代理モデルを 100,000 ステップほど訓練すると、途中のあるチェックポイントでは PIQA のような物理常識問題がよく解け、別のチェックポイントでは BoolQ の真偽質問がよく解ける。AutoMixer は、それぞれのチェックポイントを使って、どの文書がその課題の検証損失を下げる向きに働くかを近似する。たとえば「金属は熱をよく伝えるので、熱い鍋の取っ手には布を使う」といった文は、物理常識に関わる群で高い影響度を得るかもしれない。反対に、語彙は似ていても広告文のように推論の手がかりが薄い文書は、低い順位に置かれることがある。期待される出力は、文書ごとの分類ラベルではなく、複数のデータ群とそれぞれのサンプリング比率である。間違えやすい点は、影響度が高い文書を単純に「品質が高い文書」と解釈することである。この手法でいう価値は、対象タスクと特定の訓練段階に対する価値であり、一般的な文章の良否とは同じではない。