Terminal Velocity Matching: 終端速度による分布レベルの保証

Terminal Velocity Matching(TVM)は、Flow Matching を一般化し、任意の2つの拡散タイムステップ間の遷移をモデル化する手法である。MeanFlow が開始時刻 \(t\) で微分するのに対し、TVM は終端時刻 \(s\) で微分することで、生成分布とデータ分布の間の 2-Wasserstein 距離に対する明示的な上界を導出する。この理論的保証は MeanFlow にはない TVM 固有の貢献であり、実用上もアーキテクチャの Lipschitz 連続性を制御することで訓練安定性の向上につながっている。

背景と動機

Flow Matching から変位マップへ

標準的な Flow Matching は、瞬間速度場 \(v(z_t, t)\) を学習し、ODE を多ステップで解くことでサンプリングを行う。1ステップ生成を目指す場合、ODE の全軌道を1回のネットワーク評価で近似する必要があるが、軌道の曲率が大きい場合にこの近似は困難である。

MeanFlow はこの問題に対し、「平均速度(average velocity)」という量を導入した。平均速度は時刻 \(t\) から時刻 \(0\) までの変位を時間間隔で割ったものであり、開始時刻 \(t\) に関する微分条件(MeanFlow Identity)を通じて訓練される。

TVM はこのアイデアをさらに一般化し、変位マップ(displacement map) \(f(x_t, t, s)\) を導入する。これは時刻 \(t\) の状態 \(x_t\) から時刻 \(s\) の状態 \(x_s\) への変位を直接モデル化するものであり、\(t\)\(s\) の任意のペアを扱える。

開始時刻 vs 終端時刻での微分

MeanFlow と TVM の根本的な違いは、微分の方向にある。

  • MeanFlow: 開始時刻 \(t\) で微分し、MeanFlow Identity を導出
  • TVM: 終端時刻 \(s\) で微分し、Terminal Velocity 条件を導出

この違いは単なる技術的選択ではなく、理論的保証に直結する。終端時刻での微分は、変位マップが終端で正しい速度場に一致することを要求し、これにより分布レベルでの誤差の上界が導出可能となる。一方、開始時刻での微分からはこのような上界は直接得られない。

理論的枠組み

変位マップの定義

確率フロー ODE の解を \(\phi^{t \to s}(x_t)\) と書く。これは時刻 \(t\) の点 \(x_t\) を時刻 \(s\) まで流した結果である。変位マップ \(f^{t \to s}(x_t)\) は以下のように定義される:

\[ f^{t \to s}(x_t) = \phi^{t \to s}(x_t) - x_t \tag{1}\]

すなわち、変位マップは ODE 軌道上での正味の変位を表す。

スケーリングされたパラメータ化

ニューラルネットワーク \(F_\theta\) を用いて、変位マップを以下のようにパラメータ化する:

\[ f_\theta(x_t, t, s) = (s - t) \cdot F_\theta(x_t, t, s) \tag{2}\]

このスケーリングは、\(s \to t\) のとき変位が \(0\) に近づくべきであるという自然な境界条件を反映している。\(F_\theta\) 自体は有界な出力を生成すればよいため、学習が安定化する。

Terminal Velocity 条件

TVM の中核をなすのが、終端時刻 \(s\) に関する微分条件である。変位マップの終端時刻 \(s\) での微分は、速度場 \(v(x_s, s)\) に一致しなければならない:

\[ \frac{\partial f^{t \to s}(x_t)}{\partial s}\bigg|_{x_s = x_t + f^{t \to s}(x_t)} = v(x_s, s) \tag{3}\]

この条件は、変位マップが終端で正しい瞬間速度に接続することを保証する。直感的には、「終端速度が正しければ、軌道全体の精度も制御できる」という洞察に基づいている。

損失関数

TVM の損失関数は、2つの項の和として構成される:

\[ \mathcal{L}_{\text{TVM}} = \underbrace{\mathbb{E}_{t,s}\left[\lambda_{\text{TV}}(s) \left\|\frac{\partial f_\theta}{\partial s}(x_t, t, s) - v_\theta(x_t + f_\theta(x_t, t, s), s)\right\|^2\right]}_{\text{Terminal Velocity Term}} + \underbrace{\mathbb{E}_{t}\left[\lambda_{\text{FM}}(t) \left\|v_\theta(x_t, t) - u(x_t | x_0)\right\|^2\right]}_{\text{Flow Matching Term}} \tag{4}\]

Terminal Velocity TermEquation 3 の条件を最小化する項であり、変位マップの終端微分が速度場に一致するよう促す。Flow Matching Term は標準的な Flow Matching の損失であり、速度場自体の精度を保証する。\(\lambda_{\text{TV}}(s)\)\(\lambda_{\text{FM}}(t)\) はそれぞれの重み関数である。

Wasserstein 距離の上界

TVM の最も重要な理論的貢献は、生成分布とデータ分布の間の 2-Wasserstein 距離に対する明示的な上界を与える定理である。

主定理

\(f_\theta^{t \to 0}\) を学習された変位マップ(時刻 \(t\) から \(0\) への写像)、\(p_t\) を時刻 \(t\) での分布、\(p_0\) をデータ分布とする。\(f_\theta\)\(x\) に関して Lipschitz 連続であるとき、以下が成り立つ:

\[ W_2^2\!\left(f_\theta^{t \to 0} \# \, p_t,\; p_0\right) \leq \int \lambda(s)\, \mathcal{L}_{\text{TVM}}(s)\, ds + C \tag{5}\]

ここで \(f_\theta^{t \to 0} \# \, p_t\)\(p_t\) を変位マップで押し出した(pushforward)分布であり、\(\lambda(s)\) は重み関数、\(C\) は定数である。

意味と重要性

この定理は、以下の点で重要である:

  • 訓練損失と生成品質の直接的な関係: TVM 損失を小さくすれば、生成分布がデータ分布に近づくことが保証される
  • 分布レベルの保証: 個々のサンプルではなく、分布全体の近さを制御できる
  • MeanFlow にはない保証: MeanFlow の定式化からは同様の上界は導出されていない

Lipschitz 連続性の必要性

Equation 5 の成立には \(f_\theta\) の Lipschitz 連続性が不可欠である。Lipschitz 連続性が保証されない場合、変位マップの微小な入力変化に対して出力が過大に変動し、上界が無限大に発散し得る。この条件は理論的な仮定にとどまらず、実装上のアーキテクチャ設計に直接的な制約を課す(Section 1.4 を参照)。

アーキテクチャ修正

標準 DiT の問題点

Wasserstein 上界の成立に必要な Lipschitz 連続性は、標準的な DiT(Diffusion Transformer)アーキテクチャでは満たされない。具体的には、以下の2つのコンポーネントが Lipschitz 連続性を破壊する:

  • LayerNorm: 入力のノルムが小さい領域で勾配が爆発し得る
  • Scaled Dot-Product Attention: ソフトマックスの指数関数的な振る舞いにより、入力の微小な変化が出力を大きく変動させ得る

Semi-Lipschitz 制御

TVM では、完全な Lipschitz 制約(スペクトル正規化など)を課す代わりに、Semi-Lipschitz と呼ばれる緩和された制御を採用している。これは理論的に厳密な Lipschitz 定数の制御ではなく、実践的に十分な安定性を提供する設計である。

RMSNorm への置き換え:

  • LayerNorm を RMSNorm(Root Mean Square Normalization)に置き換える
  • RMSNorm は平均の引き算を行わないため、入力のノルムが小さい領域での勾配爆発が緩和される
  • Lipschitz 定数が制御可能になる

QK-normalization:

  • Self-Attention のクエリ(Q)とキー(K)に対して正規化を適用する
  • 内積のスケールが制御され、ソフトマックスの出力が安定化する
  • 結果として Attention の Lipschitz 連続性が改善される

カスタム Flash Attention カーネル

Terminal Velocity Term(Equation 4)の計算には、変位マップの終端時刻 \(s\) に関するヤコビアンベクトル積(JVP: Jacobian-Vector Product)が必要である。標準的な自動微分では、Attention 層を通じた JVP の計算はメモリ効率が悪い。

TVM では、JVP 計算を Attention の forward pass と融合したカスタム Flash Attention カーネルを開発している。これにより:

  • メモリ使用量を削減(中間活性値の保存が不要)
  • 計算速度を向上(カーネル融合による高速化)
  • 最大65%の高速化を達成(標準的な自動微分と比較)
┌──────────────────────────────────────────────────────────────┐
│  Architecture Modifications for Semi-Lipschitz Control       │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Standard DiT              TVM-Modified DiT                  │
│  ──────────────            ─────────────────                 │
│  LayerNorm          --->   RMSNorm                           │
│  Dot-Product Attn   --->   QK-Normalized Attn                │
│  Standard Autograd  --->   Fused Flash Attn (JVP)            │
│                                                              │
│  Result: Semi-Lipschitz continuity + 65% speedup             │
└──────────────────────────────────────────────────────────────┘

Classifier-Free Guidance の統合

CFG の課題

Classifier-Free Guidance(CFG)は、条件付き生成の品質を向上させる標準的なテクニックであるが、1ステップ生成モデルとの統合には固有の課題がある。標準的な CFG では、条件付き予測と無条件予測の線形結合を用いる:

\[ \tilde{v}(x_t, t, c) = v(x_t, t) + w \cdot \left(v(x_t, t, c) - v(x_t, t)\right) \]

ここで \(w\) は CFG スケール、\(c\) は条件(クラスラベルなど)である。\(w\) が大きいほど条件への忠実度は高まるが、多様性は低下する。

スケーリングされた CFG パラメータ化

TVM では、CFG スケール \(w\) を変位マップのスケーリングに組み込んだパラメータ化を採用している。標準的な CFG が速度場に対して適用されるのに対し、TVM の CFG は変位マップ全体に対して作用する。

勾配安定性のための重み付け

CFG スケール \(w\) が大きい場合、Terminal Velocity Term の勾配が不安定になる問題がある。TVM ではこの問題に対し、\(1/w^2\) の勾配重み付けを導入している:

\[ \lambda_{\text{TV}}(s, w) = \frac{1}{w^2} \cdot \lambda_{\text{TV}}(s) \tag{6}\]

この重み付けにより、高い CFG スケールでも勾配のスケールが制御され、訓練が安定化する。直感的には、CFG スケールが大きいほど変位の絶対値が大きくなるため、終端微分の誤差も増幅される。\(1/w^2\) の重みはこの増幅を打ち消す役割を果たす。

ランダム CFG サンプリング

訓練時には、各ミニバッチで CFG スケール \(w\) をランダムにサンプリングする。これにより、モデルは様々な CFG スケールに対して同時に学習し、推論時に任意の \(w\) を選択できるようになる。固定の \(w\) で訓練するよりも汎化性能が向上することが実験的に確認されている。

実験結果

ImageNet 256x256

Table 1: ImageNet 256x256 での FID スコア
設定 NFE FID 備考
TVM (\(w=2\)) 1 3.29 1-NFE での最良結果
TVM (\(w=1.5\)) 2 2.47 -
TVM (\(w=1.3\)) 4 1.99 4-NFE での最良結果
DiT (FM) 250 2.27 ベースライン

1-NFE で FID 3.29 を達成しており、これは MeanFlow の 3.43 を上回る。さらに注目すべきは 4-NFE での FID 1.99 であり、250-NFE の DiT(FID 2.27)を大幅に上回っている。

ImageNet 512x512

Table 2: ImageNet 512x512 での FID スコア
設定 NFE FID
TVM (\(w=2\)) 1 4.32
TVM (\(w=1.3\)) 4 2.94

高解像度でも TVM は有効に機能し、1-NFE で FID 4.32、4-NFE で FID 2.94 を達成している。

アブレーション

TVM の性能に影響を与える主要なハイパーパラメータについて、アブレーション実験が行われている。

時間サンプリング:

  • 訓練時の \((t, s)\) ペアのサンプリング戦略が性能に大きく影響する
  • \(s\)\(t\) の近傍に集中させるサンプリングが有効であり、遠い \((t, s)\) ペアは学習信号が弱い

EMA rate(指数移動平均の減衰率):

  • ターゲットネットワークの EMA rate は性能に敏感である
  • 高すぎる EMA rate はターゲットの更新が遅くなり、古い情報に基づいて訓練される
  • 低すぎる EMA rate はターゲットが不安定になる

スケーリング:

  • Equation 2 のスケーリング \((s-t)\) が重要であり、スケーリングなしでは訓練が不安定になる
  • スケーリングにより、短い時間間隔での変位が自動的に小さくなり、学習が容易になる

MeanFlow との比較

TVM と MeanFlow は共に Flow Matching を拡張して1ステップ生成を実現する手法であるが、いくつかの根本的な違いがある。

微分の方向:

  • MeanFlow: 変位マップの開始時刻 \(t\) に関する微分条件(MeanFlow Identity)を用いる。\(\frac{\partial}{\partial t}\left[\frac{f(x_t, t)}{t}\right]\) の形式で、平均速度の変化率を制御する。
  • TVM: 変位マップの終端時刻 \(s\) に関する微分条件(Terminal Velocity)を用いる。\(\frac{\partial f}{\partial s}(x_t, t, s)\) の形式で、終端での速度を制御する。

CFG 下での勾配安定性:

  • MeanFlow: CFG スケール \(w\) が大きい場合、開始時刻での微分が不安定になり得る。特に \(t \to 0\) の近傍で勾配が発散する問題が報告されている。
  • TVM: \(1/w^2\) の勾配重み付け(Equation 6)により、高い CFG スケールでも安定した訓練が可能である。この重み付けは終端時刻での微分構造から自然に導出される。

Wasserstein 上界の有無:

  • MeanFlow: 生成分布とデータ分布の距離に関する明示的な上界は導出されていない。訓練損失が小さくても、分布レベルでの近さは理論的に保証されない。
  • TVM: Equation 5 により、訓練損失と Wasserstein 距離の間に直接的な関係が確立されている。ただし、この上界は Lipschitz 連続性の仮定に依存する。

性能比較(ImageNet 256x256, 1-NFE):

  • MeanFlow: FID 3.43
  • TVM: FID 3.29

TVM は MeanFlow を上回っているが、差は比較的小さい。TVM の真の強みは、少数ステップ(4-NFE: FID 1.99)での性能改善と理論的保証の存在にある。

詳細: MeanFlow

訓練-推論トレードオフ

TVM の実験結果は、CFG スケールと NFE の間に興味深いトレードオフが存在することを示している。

高 CFG(\(w=2\))の場合:

  • 1-NFE では最良の FID(3.29)を達成
  • しかし 2-NFE に増やすと FID が悪化する場合がある

低 CFG(\(w=1.3\))の場合:

  • 1-NFE では FID が高い(品質が劣る)
  • しかし 4-NFE では最良の FID(1.99)を達成

この現象は、モデルの容量制限を示唆している。高い CFG スケールでは、モデルは1ステップで「強い」補正を行うように訓練されるが、この補正は粗い近似に留まる。2ステップ目を追加すると、1ステップ目の粗い補正の上にさらに補正を重ねることになり、かえって精度が低下する。

一方、低い CFG スケールでは各ステップの補正が穏やかであるため、ステップ数を増やすことで精度が単調に向上する。

┌───────────────────────────────────────────────────────────────┐
│  CFG Scale vs NFE Trade-off                                   │
├───────────────────────────────────────────────────────────────┤
│                                                               │
│  High CFG (w=2):    1-NFE [***] --> 2-NFE [**]  (degrades)    │
│  Medium CFG (w=1.5): 1-NFE [**] --> 2-NFE [***] (improves)    │
│  Low CFG (w=1.3):   1-NFE [*]  --> 4-NFE [****] (best)        │
│                                                               │
│  Key insight: Optimal CFG decreases as NFE increases          │
└───────────────────────────────────────────────────────────────┘

実用上の指針:

  • リアルタイム応用(1-NFE が必須): \(w=2\) を選択
  • 高品質生成(数ステップ許容): \(w=1.3\) で 4-NFE を選択
  • バランス型: \(w=1.5\) で 2-NFE を選択

まとめ

Terminal Velocity Matching は、Flow Matching の一般化として、終端時刻での微分条件を課すことで理論的に強い保証を獲得した手法である。

主な貢献:

  • Wasserstein 上界: 訓練損失と生成品質の間の明示的な関係を確立
  • Semi-Lipschitz アーキテクチャ: RMSNorm と QK-normalization による安定化
  • カスタム Flash Attention: JVP 融合カーネルによる最大65%の高速化
  • CFG 統合: \(1/w^2\) 重み付けとランダム CFG サンプリングによる安定訓練
  • SOTA 性能: ImageNet 256x256 で FID 3.29(1-NFE)、FID 1.99(4-NFE)

TVM の理論的枠組みは、1ステップ生成モデルにおいて「なぜ訓練損失の最小化が生成品質の向上に寄与するのか」という根本的な問いに対する数学的な回答を提供している。Lipschitz 連続性という条件は実装上の制約を伴うが、Semi-Lipschitz 制御という実践的な妥協点が有効に機能することを実験的に示している。