One-for-All Pruning: A Universal Model for Customized Compression of Large Language Models
- LLM の pruning は、単一の圧縮条件なら有効な方法が多いが、多数の利用者が異なる圧縮率や性能条件を求める場合、探索を毎回やり直す方法は遅くなる。
- 本論文は、圧縮要求を pruning 戦略へ写す StratNet と、評価を近似する Gaussian process を組み合わせた UniCuCo を提案する。
- Mistral-7B で 64 件の圧縮要求を処理する実験では、最適化ベースの手法より少なくとも 28 倍速く、精度は同程度に保たれた。
Abstract(日本語訳)
既存の大規模言語モデル(LLM)向け pruning 手法は、モデル性能を維持しながら高い圧縮率を達成することに焦点を置いている。これらの手法は、単一の利用者による圧縮要求を扱う場合には十分な性能を示してきたが、要求数に比例して処理時間が増えるため、複数の要求が同時に生じる実際の場面では効率が悪い。この制約に対処するため、本論文では LLM のための Universal Model for Customized Compression(UniCuCo)を提案する。UniCuCo は、任意の要求をその最適な pruning 戦略へ写像することを学習する StratNet を導入する。StratNet の学習には、pruning 戦略を評価する計算コストが高いこと、また pruning 過程が微分不可能であるため StratNet 更新のための勾配逆伝播が妨げられること、という難点がある。これらの難点を克服するため、著者らは Gaussian process を用いて評価過程を近似する。Gaussian process の勾配は計算可能であるため、微分不可能な pruning 過程の勾配を近似するためにそれを用いることができ、これにより StratNet の更新が可能になる。実験結果は、UniCuCo が 64 件の要求を処理する際にベースラインより 28 倍速く、同時にベースラインと同程度の accuracy を維持することを示している。
論文の面白いところ
この論文の焦点は、単に LLM を小さくすることではなく、圧縮要求が多数ある状況で pruning 戦略をどう素早く返すかにある。通常の pruning 研究では、ある圧縮率のもとで性能をどこまで保つかが主題になりやすい。これに対し本論文は、クラウド側が多様な端末や利用条件に応じて、多数の圧縮案を出す場面を前提に置く。1 件ずつ探索する手法は精度面で強いが、利用者数が増えると処理時間もほぼ線形に増える。UniCuCo は、要求から戦略への写像を一度学習しておき、以後は要求を入力するだけで戦略を返す。ここが、個別探索の代替として扱いやすい点である。さらに、Gaussian process を単なる高速な評価器として使うだけでなく、微分不可能な pruning 操作を迂回するための勾配近似にも使っている。論文中には UniCuCo を ReCoP や PuCC と呼ぶ箇所も見えるが、中心となる仕組みは StratNet と Gaussian process の組で一貫している。
問題設定
LLM compression は、事前学習済みモデルのサイズや計算量を下げつつ、言語モデルとしての性能をなるべく保つことを目的とする。pruning はその代表的な方法で、重要度の低い重みや層を取り除く。端末ごとのメモリ、推論速度、許容できる性能低下は同じではないため、実運用では一つの圧縮条件だけで足りない。たとえば、ある利用者は 25% の sparsity で性能を重視し、別の利用者は 70% の sparsity でサイズ削減を重視する。既存手法のうち、EvoPress のような最適化ベースの方法は、要求ごとに探索を行うため性能を保ちやすいが遅い。ShortGPT や Weight Subcloning などのスコアベースの方法は速いが、層の重要度スコアだけでよい pruning 順序が必ず得られるとは限らない。本論文は、複数の圧縮要求に対して、探索ベース手法に近い品質を保ちつつ、スコアベース手法に近い速度で応答する問題を扱う。
提案手法
UniCuCo は、圧縮要求を二つの目的の重みとして表す。ひとつはモデルサイズ削減、もうひとつは性能保持であり、性能保持は元の LLM と pruning 後の LLM の出力分布の KL divergence で測る。StratNet は、この要求ベクトルを入力として受け取り、各 transformer block をどの程度 pruning するかを表す戦略を出力する。出力が 0/1 の場合は block 単位で層を落とす depth pruning になり、連続値の場合は層ごとに sparsity を変える non-uniform pruning になる。二目的最適化には weighted Tchebycheff function を用いる。これは、単純な重み付き和では扱いにくい非凸な Pareto front にも対応しやすいためである。
問題は、ある pruning 戦略が性能をどれだけ保つかを評価するには、実際に pruning したモデルを作り、calibration dataset で推論しなければならない点である。この処理を毎回行うと StratNet の学習が重くなる。また pruning 戦略から pruning 後モデルへの操作は微分できないため、そのままでは勾配で StratNet を更新できない。そこで UniCuCo は Gaussian process を学習し、pruning 戦略から性能劣化の指標を予測する。Gaussian process は予測値だけでなく不確実性も返すため、LCB や UCB により、よさそうな領域と未確認の領域を適度に探索できる。さらに Gaussian process の勾配を用いることで、微分不可能な pruning 操作を直接通らずに StratNet を更新する。学習中は StratNet と Gaussian process を交互に更新し、hypervolume improvement が大きい候補を追加評価して Gaussian process の訓練データを増やす。
結果
実験では、depth pruning と non-uniform pruning の両方を扱っている。depth pruning では Mistral-7B と Llama-3-8B を中心に、WikiText-2、C4、FineWeb-Edu 上の perplexity を評価した。UniCuCo は多くの条件で、ShortGPT や Weight Subcloning などの高速なスコアベース手法より低い perplexity を示した。EvoPress は一部条件で最良の perplexity を示すが、1 件の要求につきおよそ 13 分から 26 分を要したのに対し、UniCuCo は 1 秒未満で戦略を返した。要求数を 64 件まで増やした比較では、Mistral-7B で UniCuCo の総時間は EvoPress より 56 倍短いと報告されている。abstract では、ベースラインに対して 28 倍高速で同程度の accuracy を保つと要約されている。
non-uniform pruning では、Mistral-7B の 50%、60%、70% sparsity を評価し、WikiText-2 と C4 の perplexity に加えて、ARC、HellaSwag、PiQA、WinoGrande などの zero-shot accuracy を見ている。UniCuCo は Uniform より平均 accuracy が高く、70% sparsity では 3% の改善を示した。EvoPress は accuracy で強い場合があるが、単一要求でも約 120 分を要する。OWL は EvoPress より速いものの、UniCuCo より accuracy が低い。Gaussian process を動的に更新しない設定では性能が下がり、推定器を更新し続けることが手法の要点であることも示されている。
具体例
ある企業が Mistral-7B を複数の環境に配布するとする。営業用のノート PC では応答品質をなるべく保ちたいので、担当者は「サイズ削減は中程度、性能保持を重視」という要求を出す。一方、工場内の小型端末ではメモリ制約が厳しく、「性能低下を少し許容しても 70% 近い sparsity が欲しい」という要求になる。UniCuCo では、これらの要求を二つの目的の重みとして StratNet に入力する。StratNet は、どの transformer block を残すか、または各 block をどの sparsity にするかを表す pruning 戦略を返す。Gaussian process は、その戦略を実際に毎回試さなくても、元のモデルとの出力分布のずれを近似して評価する。期待される出力は、端末ごとに異なる pruning 設定であり、同じ LLM から複数の圧縮版を速く作るための設計図にあたる。間違えやすい点は、UniCuCo が圧縮済みモデルそのものを直接生成する万能な圧縮器ではなく、要求に応じた pruning 戦略を高速に出す枠組みである点である。実際の pruning や評価には calibration dataset と対象モデルが必要であり、論文の制限でも、より大きな block 数をもつ LLM や複数 LLM を同時に扱う場合には追加検証が必要だと述べられている。