<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://liyongzhi.xyz/feed.xml" rel="self" type="application/atom+xml" /><link href="https://liyongzhi.xyz/" rel="alternate" type="text/html" /><updated>2026-05-10T00:32:04+08:00</updated><id>https://liyongzhi.xyz/feed.xml</id><title type="html">李勇志 (Yongzhi Li)</title><subtitle>Personal website of Yongzhi Li, sharing research, projects, and technical writing on multimodal generation, large language models, AI agents, and computer vision.</subtitle><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><entry><title type="html">Diffusion 模型的条件注入演进史：从通道拼接到单流 DiT</title><link href="https://liyongzhi.xyz/posts/2026/05/diffusion-condition-injection/" rel="alternate" type="text/html" title="Diffusion 模型的条件注入演进史：从通道拼接到单流 DiT" /><published>2026-05-09T00:00:00+08:00</published><updated>2026-05-09T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/05/blog-post-diffusion-condition-injection</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/05/diffusion-condition-injection/"><![CDATA[<p>如果你看过 Stable Diffusion、ControlNet、IP-Adapter，又听说过最近的 Qwen-Image 和 Z-Image，可能会有一个共同的疑问：</p>

<blockquote>
  <p>这些模型架构看上去差别很大，但它们要解决的问题其实都是同一个：<strong>怎么把”用户想要什么”这件事告诉模型？</strong></p>
</blockquote>

<p>文本 prompt、参考图、姿态骨架、深度图、音频、mask……每一种”控制信号”进入网络的方式都不一样。这篇文章想做的事情是：</p>

<blockquote>
  <p>把过去几年 Diffusion 里<strong>主流的条件注入方式</strong>串成一条线，讲清楚每一步的动机、做法、优劣，以及它们最后是怎么汇聚到今天的 DiT 体系里的。</p>
</blockquote>

<p>读完之后，你应该能：</p>

<ul>
  <li>理解为什么 inpainting 用 channel concat，而 text-to-image 用 cross-attention；</li>
  <li>知道 ControlNet 和 IP-Adapter 各自解决的是什么 prompt 解决不了的问题；</li>
  <li>看懂 adaLN-Zero 为什么是 DiT 的”默认调制器”；</li>
  <li>理解 Qwen-Image 的多流 MMDiT 和 Z-Image 的单流 S3-DiT 在条件控制上到底差在哪里；</li>
  <li>对未来 Diffusion 架构的演进方向有一个直觉性的判断。</li>
</ul>

<hr />

<h2 id="1-起点diffusion-模型为什么需要条件">1. 起点：Diffusion 模型为什么需要”条件”</h2>

<p>在最朴素的 DDPM 里，模型学的是数据分布 <code class="language-plaintext highlighter-rouge">p(x)</code>：</p>

<blockquote>
  <p>什么样的样本看起来像”真实世界中的自然图像”。</p>
</blockquote>

<p>它知道怎么把高斯噪声慢慢拉回自然图像流形，但它<strong>不知道你想要什么</strong>。给它噪声，它能生成一张图，但你没法控制是猫是狗、是写实还是油画、是白天还是夜晚。</p>

<p>所以条件 Diffusion 真正学的是：</p>

\[p(x \mid c)\]

<p>也就是”给定条件 <code class="language-plaintext highlighter-rouge">c</code> 的情况下，样本 <code class="language-plaintext highlighter-rouge">x</code> 长什么样”。这里的 <code class="language-plaintext highlighter-rouge">c</code> 可以是文本、参考图、姿态骨架、音频、深度图……几乎任何能数字化表示的”用户意图”。</p>

<p>问题来了：<strong><code class="language-plaintext highlighter-rouge">c</code> 怎么进入模型？</strong></p>

<p>这看似是个工程细节，但它直接决定了模型能不能用这个条件、能用得多准、计算成本多大。过去几年，Diffusion 社区在这个问题上其实经历了好几代演化。我们一个一个看。</p>

<hr />

<h2 id="2-第一代通道拼接-channel-concat">2. 第一代：通道拼接 (Channel Concat)</h2>

<h3 id="21-动机inpainting-是最早的条件生成">2.1 动机：inpainting 是最早的”条件生成”</h3>

<p>最早的”有条件”扩散模型其实就是 <strong>inpainting</strong> —— 给一张图、一个 mask，让模型把 mask 区域补全。这种任务的特点是：</p>

<ul>
  <li>条件本身是<strong>和图像同一空间结构</strong>的（mask 是 H×W 的图，masked image 也是 H×W 的图）；</li>
  <li>用户的意图就是”按照原图的结构和上下文，把这块补好”。</li>
</ul>

<p>那最简单的做法是什么？</p>

<p><strong>直接把 mask 和 masked image 当成额外通道，拼到 noisy latent 上一起送进 U-Net。</strong></p>

<p>Stable Diffusion Inpainting 就是这么干的。原本的 U-Net 输入是 <code class="language-plaintext highlighter-rouge">[B, 4, H, W]</code>（4 个 latent channel），inpainting 版本变成 <code class="language-plaintext highlighter-rouge">[B, 9, H, W]</code> —— 多出来的 5 个通道分别是 1 个 mask 和 4 个 masked latent。</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>unet_input = concat[
    noisy_latent,        # [B, 4, H, W]   被去噪的对象
    mask,                # [B, 1, H, W]   告诉模型哪里要改
    masked_latent,       # [B, 4, H, W]   告诉模型其它地方长什么样
]
</code></pre></div></div>

<h3 id="22-latentsync把这个思路用到-lip-sync">2.2 LatentSync：把这个思路用到 lip-sync</h3>

<p>最近字节开源的 <a href="https://github.com/bytedance/LatentSync">LatentSync</a> 把这个思路扩展到了视频口型同步。它的 U-Net 输入是 13 通道：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>unet_input = concat[
    noisy_gt_latents,    # 4 channels  当前 diffusion step 下的 noisy target
    masks,               # 1 channel   嘴部 mask
    masked_latents,      # 4 channels  嘴部被遮住的当前帧 latent
    ref_latents,         # 4 channels  同一视频里另一段参考帧 latent
]                        # total: 13
</code></pre></div></div>

<p>可以看出来，channel concat 适合的是<strong>和图像在空间上对齐的低层视觉条件</strong>：mask、被遮挡的图像、参考帧。它的优点是简单直接，VAE encoder 的输出可以直接拼上去；缺点是它没法处理<strong>变长</strong>或<strong>异构模态</strong>的条件，比如一段文字、一段音频。</p>

<h3 id="23-局限">2.3 局限</h3>

<p>如果用户的 prompt 是 “a red sports car running on the highway”，你怎么把它”拼”到一张图上？文字根本不是 H×W 的张量。</p>

<p>这就引出了下一代的方案。</p>

<hr />

<h2 id="3-第二代交叉注意力-cross-attention">3. 第二代：交叉注意力 (Cross-Attention)</h2>

<h3 id="31-动机stable-diffusion-让文本成为-prompt">3.1 动机：Stable Diffusion 让文本成为 prompt</h3>

<p>2022 年 Stable Diffusion 一炮走红的关键，不是它的 VAE，也不是它的 U-Net 结构，而是它把<strong>文本→图像</strong>这件事做成了 cross-attention：</p>

<ul>
  <li>文本经过 CLIP text encoder 变成 <code class="language-plaintext highlighter-rouge">[N_text, D]</code> 的 token 序列；</li>
  <li>U-Net 中的图像 feature 作为 query，文本 token 作为 key/value；</li>
  <li>在每一层 attention 里，图像 feature 会去”问”文本 token：你想让我画什么？</li>
</ul>

<p>数学上：</p>

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V\]

<ul>
  <li>$Q = W_q \cdot \text{image_feature}$</li>
  <li>$K = W_k \cdot \text{text_token}$</li>
  <li>$V = W_v \cdot \text{text_token}$</li>
</ul>

<p>这种设计天然适合<strong>变长、异构</strong>的条件：文本可以是 5 个词也可以是 50 个词，attention 自适应处理。</p>

<h3 id="32-不只是文本音频video-embedding-都能这么干">3.2 不只是文本：音频、video embedding 都能这么干</h3>

<p>LatentSync 用 Whisper 提取音频 embedding，然后通过 cross-attention 注入 U-Net；Wan2.1 用 T5 编码文本、CLIP 编码图像，两路条件都通过 cross-attention 进入 DiT。这套机制已经成了”高层语义条件”的事实标准。</p>

<h3 id="33-局限">3.3 局限</h3>

<p>cross-attention 不便宜。原始 DiT 论文实验过几种条件注入方案，发现 cross-attention 比 adaLN 多大约 <strong>15% 的 FLOPs</strong><sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>。当模型规模上到几十亿参数时，这个开销不容忽视。</p>

<p>更深层的问题是：cross-attention 的条件信息<strong>只在 attention 层生效</strong>，对 LayerNorm、MLP 这些非 attention 层是”透明”的。如果条件本身是一个全局信号（比如”现在是 timestep=500”、”这是猫这一类”），用 cross-attention 杀鸡用牛刀。</p>

<hr />

<h2 id="4-第三代自适应归一化-film--adaln">4. 第三代：自适应归一化 (FiLM / adaLN)</h2>

<h3 id="41-动机有些条件其实是全局调制信号">4.1 动机：有些条件其实是”全局调制信号”</h3>

<p>看一下 cross-attention 在做什么：它让每个图像 token 去关注一些条件 token。但如果条件本身就是一个<strong>全局向量</strong>（比如时间步 <code class="language-plaintext highlighter-rouge">t</code>、类别 <code class="language-plaintext highlighter-rouge">c</code>、说话人 id），让每个图像 token 都去 attend 它，纯粹是浪费算力。</p>

<p>更高效的做法是 <strong>FiLM (Feature-wise Linear Modulation)</strong>：</p>

\[\text{FiLM}(x, c) = \gamma(c) \odot x + \beta(c)\]

<p>也就是用条件 <code class="language-plaintext highlighter-rouge">c</code> 算出一组缩放和偏移参数，直接对 feature 做线性调制。</p>

<h3 id="42-adaln把-film-用到-layernorm-上">4.2 adaLN：把 FiLM 用到 LayerNorm 上</h3>

<p>普通的 LayerNorm 是这样：</p>

\[\text{LN}(x) = \gamma \cdot \text{normalize}(x) + \beta\]

<p>这里的 <code class="language-plaintext highlighter-rouge">γ, β</code> 是学习出来的固定参数。<strong>adaLN (Adaptive LayerNorm)</strong> 把它换成条件的函数：</p>

\[\text{adaLN}(x, c) = \gamma(c) \cdot \text{normalize}(x) + \beta(c)\]

<p>其中 <code class="language-plaintext highlighter-rouge">γ(c), β(c) = MLP(c)</code>。</p>

<p>DiT 的论文系统比较了几种条件注入方案——in-context conditioning、cross-attention、adaLN——最终发现 <strong>adaLN 是 FLOPs 最低、效果最好的方案</strong><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>。</p>

<h3 id="43-adaln-zero让深层-transformer-训练得更稳">4.3 adaLN-Zero：让深层 Transformer 训练得更稳</h3>

<p>DiT 在 adaLN 上又加了一个改动，叫 <strong>adaLN-Zero</strong>：除了 scale 和 shift，还预测一个 <strong>gate</strong>，并把它的初始值设为 0：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">shift</span><span class="p">,</span> <span class="n">scale</span><span class="p">,</span> <span class="n">gate</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">c</span><span class="p">).</span><span class="n">chunk</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

<span class="n">y</span> <span class="o">=</span> <span class="n">LN</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">scale</span><span class="p">)</span> <span class="o">+</span> <span class="n">shift</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">gate</span> <span class="o">*</span> <span class="n">Attention</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>   <span class="c1"># gate 初始为 0，所以这一项一开始是 0
</span></code></pre></div></div>

<p>这意味着：<strong>训练初始时，每个 block 都近似 identity function</strong>，新加进来的网络层不会立即破坏预训练的表征。这个技巧让深层 DiT 训练得非常稳，几乎成了今天所有 Transformer-based diffusion 的标配。</p>

<h3 id="44-局限">4.4 局限</h3>

<p>adaLN 的本质是<strong>全局调制</strong>：它对所有 token 施加同一组 <code class="language-plaintext highlighter-rouge">(γ, β, gate)</code>。这意味着它擅长做的事情是：</p>

<ul>
  <li>时间步注入（<code class="language-plaintext highlighter-rouge">t</code> 是一个标量）</li>
  <li>类别注入（class id 是一个 one-hot）</li>
  <li>全局风格控制（style embedding）</li>
</ul>

<p>它不擅长的事情是：</p>

<ul>
  <li>告诉模型”左手在这里”</li>
  <li>告诉模型”这条边缘要保留”</li>
</ul>

<p>要做空间结构控制，还需要更专门的机制。</p>

<hr />

<h2 id="5-第四代controlnet--给-u-net-外挂一支控制分支">5. 第四代：ControlNet —— 给 U-Net 外挂一支控制分支</h2>

<h3 id="51-动机文本说不清楚姿态">5.1 动机：文本说不清楚”姿态”</h3>

<p>文字 prompt 有一个根本局限：<strong>它没法精确描述空间结构</strong>。你说 “a person standing with left hand raised”，模型可能给你一个右手抬起来的人，或者两只手都抬着的人。</p>

<p>但如果你能给模型一张姿态骨架图、一张边缘图、一张深度图，事情就完全不一样了 —— 这些条件本身就是<strong>空间对齐</strong>的，每个像素都精确地告诉模型”这个位置应该是什么”。</p>

<h3 id="52-做法复制一份-encoder--零卷积">5.2 做法：复制一份 encoder + 零卷积</h3>

<p>ControlNet 的设计很巧妙<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup>：</p>

<ol>
  <li><strong>冻结</strong>原始 Stable Diffusion U-Net，保留它所有的预训练能力；</li>
  <li><strong>复制</strong>一份 U-Net 的 encoder（包括 down blocks 和 mid block）作为 trainable branch；</li>
  <li>condition 图（pose / edge / depth）作为这条分支的输入；</li>
  <li>这条分支在每个尺度产生 residual feature，<strong>通过零卷积加到原 U-Net 的对应层</strong>。</li>
</ol>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>condition_map (pose/edge/depth)
    ↓
[ControlNet branch (trainable copy of encoder)]
    ↓
multi-scale residuals
    ↓
加到 frozen U-Net 的 down/mid blocks
</code></pre></div></div>

<h3 id="53-零卷积让训练从零干扰开始">5.3 零卷积：让训练从”零干扰”开始</h3>

<p>ControlNet 最关键的技巧是 <strong>zero convolution</strong>：连接 ControlNet 分支和原 U-Net 的卷积层，<strong>初始权重和 bias 都是 0</strong>。</p>

<p>这意味着：</p>

<ul>
  <li>训练第 0 步：ControlNet 输出全是 0，原 U-Net 行为完全不变；</li>
  <li>随着训练进行，零卷积逐渐学到非零参数，ControlNet 的影响逐渐增强；</li>
  <li><strong>不会有训练初期”噪声破坏预训练模型”的问题</strong><sup id="fnref:3:1" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup>。</li>
</ul>

<p>这个设计让 ControlNet 在很小的数据量上就能训练成功 —— 不像 fine-tuning 那样需要担心 catastrophic forgetting。</p>

<h3 id="54-局限">5.4 局限</h3>

<ul>
  <li><strong>每种 condition 要训一个 ControlNet</strong>：pose、edge、depth、normal、segmentation 各自一个分支；</li>
  <li><strong>额外参数量不小</strong>：复制了一半 U-Net，参数量大约是原模型的 50%；</li>
  <li><strong>只控制结构，不控制语义/身份</strong>：ControlNet 告诉模型”这个位置有边缘”，但没法告诉它”这个人长得像谁”。</li>
</ul>

<p>最后这一点引出了下一个机制。</p>

<hr />

<h2 id="6-第五代ip-adapter--让图片成为-prompt">6. 第五代：IP-Adapter —— 让图片成为 prompt</h2>

<h3 id="61-动机有些东西文字真的描述不了">6.1 动机：有些东西文字真的描述不了</h3>

<p>试着用文字描述一个具体的人长什么样：</p>

<blockquote>
  <p>“a woman with long brown hair, brown eyes, oval face, slightly upturned nose…”</p>
</blockquote>

<p>写得再细，模型也只能给你一个<strong>满足这些属性的随机面孔</strong>，不是你脑子里那个<strong>具体的人</strong>。</p>

<p>类似地：</p>

<ul>
  <li>一个 logo 的精确视觉风格；</li>
  <li>一件衣服的具体花纹；</li>
  <li>一种独特的画风。</li>
</ul>

<p>这些”视觉概念”用文字 prompt 几乎不可能精确传达。但如果你能直接给模型一张<strong>参考图</strong>，事情就简单多了。</p>

<h3 id="62-朴素做法把图像和文字-token-拼起来">6.2 朴素做法：把图像和文字 token 拼起来</h3>

<p>最直观的想法是：用 CLIP image encoder 编码参考图，然后把图像 token 和文本 token 拼到一起，喂给同一个 cross-attention。</p>

<p>但 IP-Adapter 论文指出，<strong>这种朴素拼接会导致图像信息被文本特征”覆盖”</strong><sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">4</a></sup>。原因是：原模型的 cross-attention <code class="language-plaintext highlighter-rouge">K, V</code> projection 是针对文本特征训练的，强行让图像特征通过同一组 projection 进入 attention，会损失大量图像信息。</p>

<h3 id="63-解耦交叉注意力-decoupled-cross-attention">6.3 解耦交叉注意力 (Decoupled Cross-Attention)</h3>

<p>IP-Adapter 的核心创新是 <strong>decoupled cross-attention</strong>：<strong>给图像单独开一套 cross-attention</strong>，而不是和文本共享。</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">ip_adapter_attention</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">text_tokens</span><span class="p">,</span> <span class="n">image_tokens</span><span class="p">):</span>
    <span class="n">q</span> <span class="o">=</span> <span class="n">Wq</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="c1"># 原本的 text cross-attention（保持不变）
</span>    <span class="n">k_text</span> <span class="o">=</span> <span class="n">Wk_text</span><span class="p">(</span><span class="n">text_tokens</span><span class="p">)</span>
    <span class="n">v_text</span> <span class="o">=</span> <span class="n">Wv_text</span><span class="p">(</span><span class="n">text_tokens</span><span class="p">)</span>
    <span class="n">out_text</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k_text</span><span class="p">,</span> <span class="n">v_text</span><span class="p">)</span>

    <span class="c1"># 新增的 image cross-attention（新训练的）
</span>    <span class="n">k_img</span> <span class="o">=</span> <span class="n">Wk_img</span><span class="p">(</span><span class="n">image_tokens</span><span class="p">)</span>
    <span class="n">v_img</span> <span class="o">=</span> <span class="n">Wv_img</span><span class="p">(</span><span class="n">image_tokens</span><span class="p">)</span>
    <span class="n">out_img</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k_img</span><span class="p">,</span> <span class="n">v_img</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">out_text</span> <span class="o">+</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">out_img</span>
</code></pre></div></div>

<p>注意几个关键设计：</p>

<ol>
  <li><strong>复用 query</strong>：图像 attention 和文本 attention 共享同一个 <code class="language-plaintext highlighter-rouge">Q</code>，因为 query 来自图像 latent；</li>
  <li><strong>独立的 K/V projection</strong>：<code class="language-plaintext highlighter-rouge">Wk_img, Wv_img</code> 是新训练的，专门为图像特征设计；</li>
  <li><strong>加和融合</strong>：两路 attention 的输出直接相加，可以用 <code class="language-plaintext highlighter-rouge">scale</code> 控制图像 prompt 的强度。</li>
</ol>

<h3 id="64-优势极小的可训练参数">6.4 优势：极小的可训练参数</h3>

<p>IP-Adapter 论文报告，<strong>只需要约 22M 可训练参数</strong>，就能在冻结的 Stable Diffusion 上实现强大的图像 prompt 能力，并且和文本 prompt、ControlNet 等工具完全兼容<sup id="fnref:4:1" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">4</a></sup>。</p>

<h3 id="65-ip-adapter-vs-ref_latent-latentsync">6.5 IP-Adapter vs ref_latent (LatentSync)</h3>

<p>值得对比一下，因为它们看起来都是”用参考图做条件”：</p>

<table>
  <thead>
    <tr>
      <th>方式</th>
      <th>信息类型</th>
      <th>注入方式</th>
      <th>优点</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>LatentSync 的 ref_latents</strong></td>
      <td>低层像素/纹理/位置</td>
      <td>VAE latent → channel concat</td>
      <td>空间对齐强，重建保真度高</td>
    </tr>
    <tr>
      <td><strong>IP-Adapter 的 image tokens</strong></td>
      <td>高层语义/身份/风格</td>
      <td>CLIP image encoder → cross-attention</td>
      <td>泛化强，可以”风格迁移”</td>
    </tr>
  </tbody>
</table>

<p>简单来说：</p>
<ul>
  <li>想保留<strong>精确的像素结构</strong>（同一个人、同一个场景的不同角度）→ ref latent concat；</li>
  <li>想保留<strong>风格和身份语义</strong>（这个人的”长相”，但姿势可以不一样）→ IP-Adapter。</li>
</ul>

<hr />

<h2 id="7-一张表把前五代串起来">7. 一张表把前五代串起来</h2>

<p>到这里我们已经看了五种主流的条件注入方式。它们其实是<strong>互补</strong>而不是替代关系。一个现代的 Diffusion 系统通常会同时用到好几种：</p>

<table>
  <thead>
    <tr>
      <th>机制</th>
      <th>数学形式</th>
      <th>条件类型</th>
      <th>空间对齐</th>
      <th>参数量</th>
      <th>典型用途</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>Channel Concat</strong></td>
      <td><code class="language-plaintext highlighter-rouge">concat([z_t, c], dim=1)</code></td>
      <td>空间对齐的低层视觉信息</td>
      <td>强</td>
      <td>几乎为 0</td>
      <td>inpainting, ref frame, mask</td>
    </tr>
    <tr>
      <td><strong>Cross-Attention</strong></td>
      <td><code class="language-plaintext highlighter-rouge">Attn(Q=z_t, K=c, V=c)</code></td>
      <td>变长高层语义</td>
      <td>弱</td>
      <td>中</td>
      <td>text, audio, video tokens</td>
    </tr>
    <tr>
      <td><strong>adaLN-Zero</strong></td>
      <td><code class="language-plaintext highlighter-rouge">x = x + gate(c) · f(scale(c)·LN(x)+shift(c))</code></td>
      <td>全局信号</td>
      <td>无</td>
      <td>极小</td>
      <td>timestep, class, style</td>
    </tr>
    <tr>
      <td><strong>ControlNet</strong></td>
      <td><code class="language-plaintext highlighter-rouge">h_i ← h_i + ControlNet_i(cond)</code></td>
      <td>空间结构图</td>
      <td>强</td>
      <td>大（~50% U-Net）</td>
      <td>pose, edge, depth, seg</td>
    </tr>
    <tr>
      <td><strong>IP-Adapter</strong></td>
      <td><code class="language-plaintext highlighter-rouge">Attn_text + λ · Attn_image</code></td>
      <td>参考图语义/身份</td>
      <td>中</td>
      <td>极小（~22M）</td>
      <td>reference image, style</td>
    </tr>
  </tbody>
</table>

<p>注意一个有意思的事实：<strong>这五种机制里，真正被 diffusion 去噪的只有 <code class="language-plaintext highlighter-rouge">z_t</code> 一个</strong>。其他所有的”条件” —— 不管是 mask、文本 token、ControlNet residual、image token —— 都只是<strong>改变 U-Net 对 <code class="language-plaintext highlighter-rouge">z_t</code> 噪声预测方向的指引</strong>，它们自己不会被更新。</p>

<p>理解这一点，对接下来看 DiT 的演进非常重要。</p>

<hr />

<h2 id="8-第六代从-u-net-到-dit条件注入也跟着进化">8. 第六代：从 U-Net 到 DiT，条件注入也跟着进化</h2>

<h3 id="81-为什么大家开始用-transformer-替代-u-net">8.1 为什么大家开始用 Transformer 替代 U-Net</h3>

<p>U-Net 的核心是卷积 + skip connection，它的归纳偏置很适合处理图像，但它有几个问题：</p>

<ol>
  <li><strong>难以 scale</strong>：当模型规模上到 10B 以上，U-Net 的训练不如 Transformer 稳定；</li>
  <li><strong>跨模态融合受限</strong>：U-Net 主要靠 cross-attention 注入文本/图像条件，深度有限；</li>
  <li><strong>不适合统一多模态</strong>：当条件不只是文本、还包括图像、音频、视频时，U-Net 的结构不够灵活。</li>
</ol>

<p>DiT (Diffusion Transformer) 解决了前两个问题：把 U-Net 换成 ViT 风格的 Transformer，把图像 patchify 成 token 序列，然后用 Transformer block 去噪<sup id="fnref:2:1" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>。</p>

<p>但<strong>条件怎么注入到 DiT 里</strong>，又出现了不同的设计哲学。这就引出了今天最热的两条路线：<strong>多流 MMDiT</strong> vs <strong>单流 S3-DiT</strong>。</p>

<h3 id="82-多流-mmdit文本和图像各走一条流">8.2 多流 MMDiT：文本和图像各走一条流</h3>

<p>代表作是阿里的 <strong>Qwen-Image</strong><sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup>。它是一个 20B 参数的多模态 Diffusion Transformer，核心架构包括三个组件：</p>

<ol>
  <li><strong>冻结的 Qwen2.5-VL</strong>（视觉语言模型）：负责文本和图像的语义对齐；</li>
  <li><strong>VAE encoder/decoder</strong>：负责图像潜变量的压缩和重建；</li>
  <li><strong>MMDiT diffusion backbone</strong>：负责在 latent 空间去噪。</li>
</ol>

<p>它的”多流”体现在：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>text prompt ─→ Qwen2.5-VL ─→ semantic tokens ─┐
                                              ├─→ MMDiT (cross-modal fusion) ─→ noise
input image ─→ VAE ───────→ reconstruction tokens ┐
              + Qwen2.5-VL → semantic tokens   ──┘
              ↑
              这就是 "dual encoding" 机制
</code></pre></div></div>

<p><strong>Dual encoding 是 Qwen-Image 的关键创新</strong>：同一张输入图同时被 Qwen2.5-VL 和 VAE 编码，前者提供语义信息，后者提供重建信息，两者在 MMDiT 里融合。这种设计在图像编辑任务上特别有用 —— 编辑既要”理解你想改什么”（语义），也要”保持原图其他地方不变”（重建保真度）<sup id="fnref:5:1" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup>。</p>

<h3 id="83-单流-s3-dit所有-token-拼成一条序列">8.3 单流 S3-DiT：所有 token 拼成一条序列</h3>

<p>代表作是 <strong>Z-Image</strong>，由 Tongyi Lab 提出的 6B 参数 Scalable Single-Stream Diffusion Transformer (<strong>S3-DiT</strong>)<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">6</a></sup>。</p>

<p>它的设计哲学和 MMDiT 完全相反：<strong>所有模态的 token 都拼到一条序列里，共享同一个 Transformer 处理</strong>：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[text tokens | visual semantic tokens | noisy image VAE tokens]
                         ↓
                Single Transformer (S3-DiT)
                         ↓
             只取 image token 部分预测 noise
</code></pre></div></div>

<p>S3-DiT 的几个关键技术点：</p>

<ol>
  <li><strong>3D RoPE</strong>：把文本、空间、通道位置都编码到同一空间；</li>
  <li><strong>轻量的模态 stem</strong>：每种模态有自己的小 MLP，把它映射到共享的 hidden space；</li>
  <li><strong>FiLM 式条件适配器</strong>：timestep 和 global condition 通过 FiLM-like scale/shift 注入；</li>
  <li><strong>流匹配 + 蒸馏</strong>：用 flow matching loss 训练，并通过蒸馏得到 Z-Image-Turbo，可以在消费级 GPU 上做亚秒级推理<sup id="fnref:6:1" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">6</a></sup>。</li>
</ol>

<p><strong>为什么 6B 单流能挑战 20B 多流？</strong> 因为单流架构的参数效率更高 —— 文本和图像共享同一套 Transformer 权重，避免了双流之间的冗余表征。</p>

<h3 id="84-两种路线的对比">8.4 两种路线的对比</h3>

<table>
  <thead>
    <tr>
      <th>维度</th>
      <th>多流 MMDiT (Qwen-Image)</th>
      <th>单流 S3-DiT (Z-Image)</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>token 组织</strong></td>
      <td>文本流 + 图像流，分开处理后融合</td>
      <td>所有 token 拼成统一序列</td>
    </tr>
    <tr>
      <td><strong>条件注入</strong></td>
      <td>cross-modal fusion + dual encoding</td>
      <td>unified self-attention + FiLM</td>
    </tr>
    <tr>
      <td><strong>控制风格</strong></td>
      <td>显式、模块化、可分解</td>
      <td>隐式、统一、端到端</td>
    </tr>
    <tr>
      <td><strong>优势</strong></td>
      <td>强语义、强编辑、复杂条件稳定</td>
      <td>参数效率高、结构简洁、推理友好</td>
    </tr>
    <tr>
      <td><strong>劣势</strong></td>
      <td>参数巨大、计算重、系统复杂</td>
      <td>长序列 attention 成本高，强空间控制需额外设计</td>
    </tr>
    <tr>
      <td><strong>典型场景</strong></td>
      <td>专业图像编辑、文字渲染、多条件控制</td>
      <td>高效 T2I、统一多模态生成、轻量部署</td>
    </tr>
  </tbody>
</table>

<h3 id="85-直观比喻">8.5 直观比喻</h3>

<p>如果把 Diffusion 模型比作一个画师：</p>

<ul>
  <li><strong>多流 MMDiT</strong> 像是一个团队：一个文本理解专家、一个图像理解专家、一个绘画师傅，三人开会沟通后画师傅落笔。每个人都是该领域的专家，分工明确，但维护成本高。</li>
  <li><strong>单流 S3-DiT</strong> 像是一个全能选手：他自己读 prompt、看参考图、想空间结构、动笔作画，全在一个脑子里完成。沟通成本低、效率高，但需要训练数据足够多样才能学会所有这些技能。</li>
</ul>

<hr />

<h2 id="9-一个例子图像编辑任务下两种路线怎么处理">9. 一个例子：图像编辑任务下两种路线怎么处理</h2>

<p>假设任务是：</p>

<blockquote>
  <p>“输入一张人像，把衣服改成红色西装，脸和背景保持不变。”</p>
</blockquote>

<p><strong>多流 MMDiT 的做法</strong>：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>原图          ─→ Qwen2.5-VL ─→ "这是一个穿白衬衫的男人" (语义)
              ─→ VAE        ─→ pixel-level 重建特征
文本指令       ─→ Qwen2.5-VL ─→ "把衣服改成红色西装" (编辑指令)
noisy target  ─→ MMDiT denoising stream

MMDiT 通过显式的 cross-modal fusion：
- 语义层面理解"衣服→红色西装"；
- 重建层面知道"脸和背景保持原样"；
- denoising 层面生成最终结果。
</code></pre></div></div>

<p><strong>单流 S3-DiT 的做法</strong>：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[ text instruction tokens
| input image semantic tokens
| input image VAE tokens
| noisy target tokens ]
                ↓
        Single Transformer
                ↓
        从 target tokens 输出 noise
</code></pre></div></div>

<p>模型自己通过 attention 学习”哪些 token 对哪些区域重要”。这种端到端的学习理论上更灵活，但需要大量精心标注的编辑数据才能训稳。</p>

<hr />

<h2 id="10-未来趋势混合路线--多模态统一">10. 未来趋势：混合路线 + 多模态统一</h2>

<p>看完前面这一切，我对未来 Diffusion 架构的演进有几个判断。</p>

<h3 id="101-单流会成为基础架构但不会纯单流">10.1 单流会成为基础架构，但不会”纯单流”</h3>

<p>单流的优势是简洁和效率，但它在<strong>强空间控制</strong>（pose、depth、edge）和<strong>复杂编辑</strong>上还有差距。最实用的系统大概率是 hybrid：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Single-stream DiT backbone (主干)
    + Control branch (空间结构控制，类似 ControlNet)
    + Reference adapter (参考图，类似 IP-Adapter)
    + Mask/edit branch (区域级编辑)
    + Layout/typography expert (排版、文字渲染)
</code></pre></div></div>

<p>也就是说，<strong>单流解决”统一和效率”，专门分支解决”精确和可控”</strong>。</p>

<h3 id="102-控制粒度从-prompt-level-走向-region-level">10.2 控制粒度从 prompt-level 走向 region-level</h3>

<p>现在的主流是”一句 prompt 控制整张图”。但实际场景中，用户更需要：</p>

<ul>
  <li>这个区域保持不变；</li>
  <li>这个物体换材质；</li>
  <li>这个人保持身份；</li>
  <li>这几个字必须准确显示；</li>
  <li>这个 logo 不能变形。</li>
</ul>

<p>这要求模型支持<strong>区域绑定、对象绑定、文字位置绑定、多参考图绑定</strong>。Qwen-Image 强调的”复杂文字渲染”和”精确图像编辑”已经在往这个方向走。</p>

<h3 id="103-生成与编辑会统一为一个模型">10.3 生成与编辑会统一为一个模型</h3>

<p>以前模型经常分得很细：T2I model、inpainting model、editing model、ControlNet model……。未来会逐渐合并：</p>

<blockquote>
  <p>一个模型同时支持 T2I、image editing、multi-image composition、style transfer、layout generation、text rendering、object replacement。</p>
</blockquote>

<p>Qwen-Image 的 multi-task training 已经包括 T2I、TI2I (text+image→image)、I2I reconstruction<sup id="fnref:5:2" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup>。这说明主流方向已经是<strong>统一训练</strong>，而不是每个任务单独一个模型。</p>

<h3 id="104-高效化会变成核心竞争力">10.4 高效化会变成核心竞争力</h3>

<p>Z-Image 的 6B 单流路线说明：<strong>不是只能靠堆参数</strong>。数据质量、架构设计、蒸馏和推理优化同样重要。未来会越来越重视：</p>

<ul>
  <li>few-step generation（4 步、2 步、甚至 1 步采样）；</li>
  <li>distillation（把大模型的能力蒸馏到小模型）；</li>
  <li>FP8 / INT8 / NF4 quantization；</li>
  <li>KV cache / feature cache；</li>
  <li>MoE DiT（稀疏激活的 Diffusion Transformer）；</li>
  <li>consumer GPU fine-tuning。</li>
</ul>

<h3 id="105-跨模态联合学习将成为标配">10.5 跨模态联合学习将成为标配</h3>

<p>未来的”条件”不会只有文本和图像。视频、音频、3D、甚至 IMU/sensor 数据都会一起训练。模型需要更复杂的条件注入策略来处理多模态信息 —— adaLN、cross-attention、ControlNet-like branch、IP-Adapter-like adapter 都会在不同位置发挥作用。</p>

<hr />

<h2 id="11-总结">11. 总结</h2>

<p>回到一开始那个问题：<strong>怎么把”用户想要什么”告诉模型？</strong></p>

<p>过去几年，Diffusion 社区给出的答案大致是这样一条主线：</p>

<ol>
  <li><strong>Channel Concat</strong>：条件和图像在同一空间结构 → 直接拼通道。简单粗暴，适合 inpainting。</li>
  <li><strong>Cross-Attention</strong>：条件是变长语义 → 让图像 token 去 attend 它。文本/音频条件的标配。</li>
  <li><strong>adaLN-Zero</strong>：条件是全局信号 → 用它生成 LayerNorm 的 scale/shift/gate。极低开销，DiT 标配。</li>
  <li><strong>ControlNet</strong>：条件是空间结构图 → 复制一份 encoder + 零卷积。强空间控制，但每种条件要训一个分支。</li>
  <li><strong>IP-Adapter</strong>：条件是参考图 → 解耦 cross-attention，给图像单开一套。极小参数实现图像 prompt。</li>
  <li><strong>多流 MMDiT (Qwen-Image)</strong>：文本和图像各走一条流，通过 cross-modal fusion 融合。强编辑、强语义。</li>
  <li><strong>单流 S3-DiT (Z-Image)</strong>：所有 token 拼成一条序列共享 Transformer。参数高效、推理友好。</li>
</ol>

<p><strong>核心洞察</strong>：这些机制不是替代关系，而是互补的。一个现代 Diffusion 系统往往同时用到多种 —— DiT 主干用 adaLN-Zero 做时间步调制，cross-attention 接文本，ControlNet 做空间控制，IP-Adapter 做参考图。理解每种机制的”擅长什么、不擅长什么”，比死记某一个架构更有用。</p>

<p>未来的方向是<strong>单流主干 + 多种专门分支的 hybrid 架构</strong>，再叠加蒸馏、量化和高效推理。从条件注入的角度看，这场演化远没有结束。</p>

<hr />

<h2 id="sources">Sources</h2>

<ul>
  <li><a href="https://arxiv.org/abs/2006.11239">DDPM (Ho et al., 2020)</a></li>
  <li><a href="https://github.com/bytedance/LatentSync">LatentSync (字节, 2024)</a></li>
  <li><a href="https://arxiv.org/abs/2112.10752">Stable Diffusion / Latent Diffusion (Rombach et al., 2022)</a></li>
  <li><a href="https://arxiv.org/abs/2212.09748">DiT: Scalable Diffusion Models with Transformers (Peebles &amp; Xie, 2022)</a></li>
  <li><a href="https://arxiv.org/abs/2302.05543">ControlNet (Zhang et al., 2023)</a></li>
  <li><a href="https://arxiv.org/abs/2308.06721">IP-Adapter (Ye et al., 2023)</a></li>
  <li><a href="https://github.com/QwenLM/Qwen-Image">Qwen-Image 官方仓库</a></li>
  <li><a href="https://arxiv.org/abs/2508.02324">Qwen-Image Technical Report</a></li>
  <li><a href="https://github.com/Tongyi-MAI/Z-Image">Z-Image / S3-DiT</a></li>
  <li><a href="https://arxiv.org/abs/1709.07871">FiLM (Perez et al., 2017)</a></li>
  <li><a href="https://arxiv.org/abs/2210.02747">Flow Matching for Generative Modeling (Lipman et al., 2023)</a></li>
</ul>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>DiT 论文中比较了 in-context conditioning、cross-attention、adaLN 三种条件注入方式，cross-attention 比 adaLN 多约 15% Gflops。 <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Peebles &amp; Xie. <em>Scalable Diffusion Models with Transformers</em>. ICCV 2023. adaLN-Zero 把 residual block 中调制参数初始化为零，使 block 初始接近 identity function。 <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:2:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>Zhang et al. <em>Adding Conditional Control to Text-to-Image Diffusion Models</em>. ICCV 2023. ControlNet 通过 zero-initialized convolution 让参数从零逐渐增长，避免训练初期破坏预训练模型。 <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:3:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
    <li id="fn:4" role="doc-endnote">
      <p>Ye et al. <em>IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models</em>. 2023. 关键设计是 decoupled cross-attention，把 text 和 image 的 cross-attention 分开。 <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:4:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
      <p>Qwen-Image 是 20B MMDiT 模型，使用 Qwen2.5-VL + VAE 双编码机制，支持 T2I、TI2I、I2I 多任务训练，强调复杂文字渲染和图像编辑能力。 <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:5:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a> <a href="#fnref:5:2" class="reversefootnote" role="doc-backlink">&#8617;<sup>3</sup></a></p>
    </li>
    <li id="fn:6" role="doc-endnote">
      <p>Z-Image 是 6B 参数的 Scalable Single-Stream Diffusion Transformer (S3-DiT)，把 text、visual semantic tokens、image VAE tokens 拼成统一输入流，通过 3D RoPE 和 FiLM 适配器注入条件，支持亚秒级推理。 <a href="#fnref:6" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:6:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
  </ol>
</div>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="diffusion" /><category term="generative model" /><category term="DiT" /><category term="ControlNet" /><category term="IP-Adapter" /><category term="multimodal" /><summary type="html"><![CDATA[如果你看过 Stable Diffusion、ControlNet、IP-Adapter，又听说过最近的 Qwen-Image 和 Z-Image，可能会有一个共同的疑问：]]></summary></entry><entry><title type="html">Agentic RL 训练全景：环境、信号、分布与系统的协同闭环</title><link href="https://liyongzhi.xyz/posts/2026/04/agentic-rl/" rel="alternate" type="text/html" title="Agentic RL 训练全景：环境、信号、分布与系统的协同闭环" /><published>2026-04-28T00:00:00+08:00</published><updated>2026-04-28T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-agentic-rl</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/agentic-rl/"><![CDATA[<blockquote>
  <p>过去一年，各家大模型公司公开的技术报告透出的最重要信号，不是又出现了一个更好的 PPO/GRPO 变体，而是<strong>真正有效的 Agentic RL 已经从”单轮文本优化”转向了”在长上下文、工具调用、部分可观测、异步执行环境中的系统性策略学习”</strong>。</p>

  <p>Kimi K1.5[1] 把长上下文 RL、partial rollout 重用和 mirror-descent 风格的 policy optimization 拉到了台前；Kimi K2[2]/K2.5[3] 又把 agentic 数据合成、多模态 RL、token-level clipping、GRM rubric、Toggle、PARL / Agent Swarm 这些关键部件公开；MiniMax 把另一个事实讲得更彻底：<strong>当 rollout 时长从秒级扩到小时级，训练瓶颈就不再是 loss design，而是吞吐、稳定性与 agent 灵活性之间的三难权衡</strong>；GLM 则强调分阶段 RL：Reasoning RL、Agentic RL、General RL 不是混在一起一次训完，而是通过顺序化 pipeline 逐步推进，并借助异步 RL 基础设施与跨阶段蒸馏来兼顾长时程 agent 学习与能力保持。</p>

  <p>Agentic RL 的核心问题，已经从”怎么更新参数”扩展为”<strong>怎么在真实 Agent 环境里持续制造可用的学习信号，并用在线交互的轨迹数据驱动优化</strong>“。</p>
</blockquote>

<hr />

<h2 id="一为什么-agentic-rl-与传统-rlhf--rlvr-不同">一、为什么 Agentic RL 与传统 RLHF / RLVR 不同</h2>

<p>Agentic RL 的训练对象不再是”给定一个 prompt，输出一个答案”的单轮文本映射，而是<strong>一个在环境中交互的策略</strong>。这个策略要处理：状态更新、工具调用、外部观察、上下文整理、子任务委派、终止条件判断，以及成本 / 时延 / 安全约束。换句话说，agentic RL 更像是在做一类带有长时间尺度、部分可观测性和结构化动作空间的策略学习，而不是简单地对文本续写概率做后验重排。</p>

<p>这直接带来四个训练上的变化：</p>

<ol>
  <li><strong>状态不再只由用户输入决定</strong>：它由历史轨迹、工具返回、环境回馈、记忆摘要和当前上下文共同构成。</li>
  <li><strong>动作也不再只是下一个 token</strong>：它可能是”选哪个工具、填什么参数、要不要压缩上下文、是否并行分派子任务”。</li>
  <li><strong>奖励更延迟、更稀疏、更复合</strong>：既要看结果对不对，也要看过程是否准确、是否高效、是否节省 token 和单位时间有效训练效率。</li>
  <li><strong>Rollout 时间高度不均匀</strong>：同步训练代价高，异步训练又引入分布偏移。</li>
</ol>

<p>因此，agentic RL 的本质<strong>不是把 GRPO / PPO 套到更长的输出上</strong>，而是把环境、奖励、采样、调度、缓存、优化器和评测接到同一个闭环里。</p>

<hr />

<h2 id="二理解-agentic-rl-的三个不变量">二、理解 Agentic RL 的三个不变量</h2>

<p>如果把 Agentic RL 理解成一个”在真实环境里持续交互、持续采样、持续更新”的策略学习系统，那么真正重要的就不再是”这一步用哪种 RL 算法”，而是<strong>训练闭环能否长期守住三个更底层的条件</strong>。</p>

<p>这里的”不变量”不是指某个量在数学上严格恒定，而是指它们虽然会天然漂移，却必须在整个训练过程中被不断拉回到一个仍然可学习、可优化的区间里。前两个是<strong>不应跌破的下限</strong>，第三个是<strong>不应越过的上限</strong>。</p>

<h3 id="1第一不变量策略的可探索空间不能过早塌缩">1）第一不变量：策略的可探索空间不能过早塌缩</h3>

<p>第一不变量<strong>不是</strong>要求输出更随机、token 熵更高，而是：<strong>模型在给定状态下，仍然保有一组彼此可区分、语义上不同、并且真实可行的行为路径</strong>。</p>

<p>对 Agentic RL 来说，这个探索空间不只是”不同措辞”，而是：</p>

<ul>
  <li>不同的<strong>任务分解</strong>方式</li>
  <li>不同的<strong>工具调用顺序</strong></li>
  <li>不同的<strong>记忆读写</strong>策略</li>
  <li>不同的<strong>上下文整理</strong>方式</li>
  <li>不同的<strong>停止条件</strong>与<strong>自我修正</strong>路径</li>
</ul>

<p>它之所以会塌缩，是因为训练天然会把概率质量压向”少数当前最占优的模式”。只要训练目标主要奖励”更短、更像标准流程、更容易被 verifier 识别”的行为，模型就会把其他原本也可能成功的路径边缘化。在 agent 场景下，这种压缩比单轮问答更严重——工具接口、scaffold、上下文模板和终止逻辑本身就会<strong>暗中偏好</strong>某类固定 workflow。</p>

<p><strong>保持这一不变量的意义</strong>：它决定了后续 RL 是否还有真正的搜索空间。RL 的价值不是把已知最好答案重复推高概率，而是让模型在交互中持续发现”此前还没被放大的高回报行为”。如果可探索空间已经提前塌缩，后面的采样大多只是对同一种套路做表面扰动，reward spread 越来越小，训练看似还在继续，实际上只是在一个已经缩水的空间里做局部扰动。</p>

<h3 id="2第二不变量学习信号必须持续非退化">2）第二不变量：学习信号必须持续非退化</h3>

<p>即使模型仍然保有多种可行路径，这些路径也不一定会被<strong>学到</strong>。参数更新依赖的不是”存在别的可能性”，而是<strong>不同轨迹之间的差异能否稳定地转成非零、方向明确、尺度合理的梯度</strong>。</p>

<p>Agentic RL 的奖励结构天然容易让信号塌缩：真实任务奖励延迟、结果稀疏、过程很长，最终常常只有成败标签、粗粒度 rubric，或少数高层质量分。于是同一组采样很容易出现两种退化情形——</p>

<ul>
  <li><strong>简单任务几乎全对</strong>（模型已在该局部饱和）</li>
  <li><strong>困难任务几乎全错</strong>（模型尚未进入可学习区域）</li>
</ul>

<p>但对梯度而言，这两类样本都会导向同一个结果：<strong>组内没有足够差异，优势接近消失，更新方向随之退化</strong>。再叠加长轨迹的信用分配、部分可观测性带来的归因模糊、工具噪声和 verifier 噪声对比较关系的污染——系统表面上在大量收集交互数据，实际上却在不断生产”不可学样本”。</p>

<p><strong>这里有一个关键观察</strong>：学习信号的质量，不取决于奖励项有多少，而取决于<strong>比较是否可学</strong>。奖励可以很复杂，但如果它无法在模型当前边界附近稳定区分”略好”与”略差”的轨迹，它仍然会产生退化梯度。反过来，一个看上去更简单的反馈，只要能持续打开轨迹间的有效差异，也能成为高质量学习信号。</p>

<p><strong>第二不变量真正要求保持不变的，不是奖励总量，而是可比较性与可更新性。</strong></p>

<h3 id="3第三不变量训练--更新--部署三者的分布偏移必须可控">3）第三不变量：训练 / 更新 / 部署三者的分布偏移必须可控</h3>

<p>前两个不变量解决”还有没有别的路径”和”这些路径能不能变成梯度”，第三个不变量解决<strong>这些梯度是不是作用在了正确的分布上</strong>。</p>

<p>在 Agentic RL 中，有三个天然不一致的分布：</p>

<ul>
  <li>策略模型<strong>采样</strong>出的 rollout 分布</li>
  <li>learner 真正<strong>拿来更新</strong>的样本分布</li>
  <li>最终<strong>部署执行</strong>的策略分布</li>
</ul>

<p>Agent 训练会持续制造分布漂移：</p>

<ul>
  <li>轨迹长短差异极大，严格同步的 on-policy 不现实，异步采样、缓存、续跑、复用、过滤都会让”生成样本时的策略”和”更新参数时的策略”发生时间错位；</li>
  <li>Agent 状态由工具返回、环境反馈、上下文裁剪、记忆摘要、调度决策共同构成，只要其中任何一层在 rollout / training / serving 三阶段的表示不完全一致，模型学到的可能就不是同一个动作语义；</li>
  <li>训练和部署脚手架常常并不完全相同：解码设置、context packing、tool schema、tokenizer/engine、middleware、日志序列化方式都会改变模型真正面对的决策问题。</li>
</ul>

<p>结果是：被优化的不再是一个干净统一的策略分布，而是<strong>多个相似但不相同的分布拼接而成的近似对象</strong>。</p>

<p>对长轨迹 Agent，这一点尤其致命——轨迹越长，前面每一点小的偏移都会沿着后续状态转移不断累积，最终把策略推向”在训练里看起来合理、在真实环境里却不可执行”的方向。</p>

<p><strong>Agentic RL 里的分布偏移，不只是外部环境变化带来的，它在很大程度上是系统自己制造出来的</strong>。这也是为什么第三不变量不是单纯的算法修正问题，而是一个系统级的一致性问题。</p>

<h3 id="4为什么这三个不变量要放在一起理解">4）为什么这三个不变量要放在一起理解</h3>

<p>它们不是彼此独立的三条要素，而是<strong>同一个训练系统的三个耦合边界</strong>：</p>

<table>
  <thead>
    <tr>
      <th>不变量</th>
      <th>本质问题</th>
      <th>失守后果</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>第一</td>
      <td>策略空间是否还够宽</td>
      <td>没有可探索的新路径</td>
    </tr>
    <tr>
      <td>第二</td>
      <td>空间里的差异能否转成有效梯度</td>
      <td>有路径但学不到</td>
    </tr>
    <tr>
      <td>第三</td>
      <td>梯度是否作用在正确分布上</td>
      <td>学到的行为在部署时失真</td>
    </tr>
  </tbody>
</table>

<ul>
  <li>只有<strong>探索</strong>没有<strong>信号</strong>：训练变成高噪声试错；</li>
  <li>只有<strong>信号</strong>没有<strong>探索</strong>：训练迅速收缩到狭窄局部最优；</li>
  <li>探索和信号都有，但<strong>分布偏移失控</strong>：学到的也不是部署时真正需要的行为。</li>
</ul>

<p>它们彼此之间还天然存在张力：探索更强，会让比较更稀、分布偏移更难控；过度追求稳定更新，又容易压平探索空间；为了制造更锋利的信号把 verifier 设计得过于严格，又会让模型朝少数投机模式收缩。</p>

<p><strong>Agentic RL 真正要解决的，不是把某个 loss 降得更低，而是在一个持续变化、持续异步、持续与外部环境交互的系统里，始终把探索、信号和分布维持在同一个可学习区间内。</strong></p>

<hr />

<h2 id="三agentic-rl-的九个关键维度">三、Agentic RL 的九个关键维度</h2>

<p>三个不变量是”要守住什么”；下面九个维度是”在哪些具体位置守”。前八个维度对应训练系统的核心环节，第九个维度（评测与可观测性）回答的是一个更基础的问题——<strong>如果你连三个不变量是否正在被守住都测不出来，就根本谈不上管理它们</strong>。</p>

<h3 id="1-环境与接口建模先搞清楚环境允许-agent-做什么再谈正确答案是什么">1. 环境与接口建模：先搞清楚”环境允许 Agent 做什么”，再谈”正确答案是什么”</h3>

<p>Agentic RL 和普通”对一道题生成一个答案”的最大差别在于：模型不再只是从 prompt 里猜一个 completion，而是在一个<strong>可交互、可执行、带状态转移</strong>的世界里学 policy。</p>

<p>决定 Agentic RL 训练效果的<strong>第一个变量</strong>不是 reward model，而是环境和接口本身是否设计清楚：</p>

<ul>
  <li>每一步模型能看到哪些信息？</li>
  <li>能采取哪些动作、哪些工具调用是允许且有效的？</li>
  <li>任务在什么条件下结束、成功如何判断？</li>
  <li>训练时使用的工具接口和交互流程，是否和真实部署一致？</li>
</ul>

<p>当前几家的共识非常一致：</p>

<ul>
  <li>Kimi K2 把大规模 agentic 数据合成 + 真实/合成环境 RL 放进后训练主线；</li>
  <li>K2.5 把 Agentic RL 统一到 <strong>Gym-like 接口</strong>，并支持大规模异步任务管理；</li>
  <li>GLM-5[8] 把 agentic RL 扩展到<strong>超过 10K 个可验证的软件工程环境、terminal 环境和多跳搜索任务</strong>；</li>
  <li>Forge[9] 强调系统跨越了十万级 real-world scaffolds 与数千种工具调用格式。</li>
</ul>

<p>真正的 agentic capability 不是从静态数据里背下来的，而是从<strong>结构化、可验证、可迁移的环境</strong>里训练出来的。</p>

<p>环境建模的核心，不是把现实世界完整模拟出来，而是把真实工作转写成一个<strong>结构上不失真的可训练决策过程</strong>——重要的不是表面真实，而是 <strong>structural fidelity</strong>：动作空间、关键信息流、失败模式和成功判据，是否与真实部署保持一致。举一个典型例子：一个客服 agent 不必复现公司所有噪声，但必须保留库存状态、退款规则、权限边界、上下文记忆、工具接口、升级流程和最终评分 rubric；否则学到的只是”像在做客服”，而不是”真的能做客服”。</p>

<p><strong>环境覆盖度，是 Agentic RL 的第一条 scaling axis</strong>。但真实任务的难点往往不在 data scaling，而在 <strong>specification scaling</strong>：很多高价值任务之所以难进训练闭环，不是因为模型不够聪明，而是因为任务没有被写成机器可执行、机器可验证的规范。下一代 env scaling 更像三个”编译器”问题：</p>

<ul>
  <li><strong>task compiler</strong>：把模糊请求编译成初始状态、工具、约束和终止条件；</li>
  <li><strong>verifier compiler</strong>：把”做得好不好”编译成可执行检查、rubric 和必要时的人类审阅；</li>
  <li><strong>scaffold compiler</strong>：把同一能力放进不同 agent scaffold、tool schema 和 orchestration loop，避免模型只记住单一 workflow。</li>
</ul>

<p>Forge 强调跨大规模 scaffold 训练，本质上就在处理第三个问题。真实人类任务里最大的问题不是”任务太少”，而是 <strong>evaluator 太弱</strong>——一旦 verifier 失真，模型就会学会 hacking，而不是学会工作。SWE-Universe[10] 把环境构建、self-verification 和 hacking detection 自动化，说明大家已经开始把”防投机评测”当成环境的一部分。</p>

<h3 id="2-探索能力与多样性保持不是把-temperature-调高而是维护可探索行为的空间">2. 探索能力与多样性保持：不是把 temperature 调高，而是维护可探索行为的空间</h3>

<p>很多人一谈”探索”就想到：调高 temperature、多采几个 rollout、加 entropy regularization。但对 agentic RL，这些都只是表层现象。核心问题是：<strong>模型在训练的不同阶段，是否仍然保有一组彼此可区分、都可能成功、且在参数空间里真实可达的行为路径</strong>。</p>

<p>对 reasoning 模型，这个问题已经被直接观察到：随着 SFT 推进，Pass@1 可以继续上升，但 <strong>Pass@k 会快速恶化</strong>，而且后续 RL 往往也恢复不了；仅靠 token-level 的多样化解码，距离理论上的 oracle 上界仍有明显差距。<strong>真正塌缩的不是采样温度，而是模型权重层面的行为可探索空间。</strong></p>

<p>所以这一节最本质的思想是：<strong>探索本质上是一个 support management 问题</strong>。你要管理的不是 token 级噪声，而是模型是否还保有：</p>

<ul>
  <li>多种合法任务分解</li>
  <li>多种工具调用顺序</li>
  <li>多种上下文组织方式</li>
  <li>多种长度的 reasoning path</li>
  <li>在 agent 场景下的多种 memory / planning / action 组合</li>
</ul>

<p>只要这些分支在参数里还活着，后续 RL 才有可能通过 verifier 和 rollout 把它们放大；一旦在进入 RL 前就被压没，训练再稳定也只是在缩水的空间里做局部优化。</p>

<p><strong>预训练 / 基座阶段</strong>决定的是 <strong>reachable support</strong>——模型是否已经具备足够多的技能碎片、长上下文耐受性、工具使用先验和任务分解能力：</p>

<ul>
  <li>MiniMax-M1[11] 把额外 7.5T continual pretraining 直接称为 “<em>Foundation for RL Scaling</em>“；</li>
  <li>Kimi K2 用 diverse agents、tool combinations 和 rubric-guided tasks，把未来 agent 可能探索的 action space 和 task space 提前做宽；</li>
  <li>DeepSeek-R1-Zero[12] 提供了另一个很有代表性的例子：它在没有 SFT 冷启动的前提下直接 RL，模型会自然增加思考时长，并逐步长出更长推理和自我修正的行为——这说明对能力足够强的基础模型，<strong>RL 过程本身就可能激发并放大更长程的推理与自我修正行为</strong>。</li>
</ul>

<p><strong>冷启动 / SFT 阶段</strong>真正要解决的，不仅是”把模型教得更会答题”，而是<strong>不要在进入 RL 之前就把分布压塌</strong>：</p>

<ul>
  <li>GEM[4] 的重要性不在于又提出一个新的 SFT loss，而在于它把问题说透了：标准交叉熵 SFT 会压缩输出分布，抹掉很多 alternative plausible outputs，而在线 RL 恰恰需要这些行为分歧来形成探索空间；</li>
  <li>Getting Your LLMs Ready for RL[13] 进一步指出：最适合接 RL 的 checkpoint，<strong>往往不是 validation 上表现最好的那个</strong>——在传统过拟合发生之前，模型就可能已经出现 distributional forgetting，过度偏离 base distribution，从而损害后续 RL 的潜力。</li>
</ul>

<p><strong>到了在线 RL 阶段</strong>，探索问题又会表现成另一种形态：即便模型内部还保留着多种路径，如果 RL 目标只盯 correctness，训练仍然会把概率质量持续推向少数高回报模式：</p>

<ul>
  <li>DAPO[14] 把 Clip-Higher 明确写成 “<em>promotes diversity and avoids entropy collapse</em>“；</li>
  <li>Diversity-Aware Policy Optimization[15] 在 12 个 LLM 上给出更强的经验结论：solution diversity 与 Potential@k 存在强正相关，因此<strong>在 RL 目标中显式促进 token-level diversity</strong>，平均带来 3.5% 的数学推理提升。</li>
</ul>

<p>这里真正重要的不是某一个技巧，而是一个更深的转向：<strong>探索，第一次从”训练自然会保住的东西”，变成了需要被显式优化的对象</strong>。</p>

<p>这一维度今天仍有几个未解决的问题：</p>

<ol>
  <li>当前很多方法管理的仍然是 token entropy 或字符串级 diversity，但 agentic RL 真正需要保住的是<strong>语义层和策略层的多样性</strong>——不同工具顺序、不同 memory 操作、不同任务分解不一定表现为更高的 token entropy；</li>
  <li>很多系统的 verifier 偏 outcome-only，天然低估那些”短期看更绕、长期却更有价值”的探索路径；</li>
  <li>社区仍过度依赖 Pass@1，而对 Pass@k、Potential@k、解法簇数量、跨 scaffold 迁移这些更接近探索前沿的指标重视不够。</li>
</ol>

<h3 id="3-算力分配与学习信号整理谁拿到-rollout谁才真正有机会被学到">3. 算力分配与学习信号整理：谁拿到 rollout，谁才真正有机会被学到</h3>

<p>上一节讨论”多样化的采样路径是否存在”，这一节讨论<strong>在固定 rollout 预算下，这些路径里哪些会真正进入梯度</strong>。探索解决<strong>可达性</strong>，算力分配解决<strong>可学习性</strong>。</p>

<p>对 reasoning / agentic RL 来说，模型内部也许还保留着多种策略，但如果 rollout 总是平均分给”已经学会的简单题”和”暂时完全学不会的极难题”，训练既看不到组内差异，也形成不了有效梯度——在稀疏奖励和 group baseline 设置下，很多 prompt-group 会退化成全 0 或全 1，<strong>advantage energy 为 0，gate 关闭</strong>，这些组消耗了算力却没有产生 usable learning signal。</p>

<p>因此，真正该优化的目标不只是平均 reward，而是更接近训练动力学本身的量：</p>

<ul>
  <li>non-zero gradient ratio</li>
  <li>gate-open probability</li>
  <li>组内 reward spread</li>
  <li>单位训练时间内的有效样本率</li>
</ul>

<p><strong>算力分配是 credit assignment 的上游机制</strong>：谁拿到更多 rollout，谁就更有机会被比较、被区分、被学到。</p>

<p>主流做法可以分成三类——</p>

<p><strong>① 方差控制视角</strong>。既然不同 prompt 对梯度方差的贡献不同，那么 rollout 预算就应该优先投给那些最能减少估计方差、最可能恢复学习信号的 prompt：</p>

<ul>
  <li>GVM-RAFT[17] 从 acceptance rate 和 gradient noise 的角度做动态分配；</li>
  <li>VIP[18] 更系统，用轻量高斯过程预测 prompt 成功概率，再转成 gradient variance 估计，并在固定预算约束下解一个 rollout allocation 优化问题。VIP 明确把目标写成 <em>minimize the expected gradient variance of the policy update</em>，而不是机械拉高 pass rate。</li>
</ul>

<p>这标志着 rollout allocation 开始从经验 heuristics 变成 <strong>policy optimization 的一部分</strong>。</p>

<p><strong>② 学习价值—成本权衡视角</strong>。Knapsack RL[6] 把每个任务的探索看成”具有不同 value 和 cost 的 item”，由此推出自适应资源分配规则——把预算从已经学饱和的题转移到更可能产出信号的题。预算分配不是为了省钱，而是<strong>避免把大量算力烧在注定不会更新参数的地方</strong>。</p>

<p><strong>③ 主动恢复信号视角</strong>。Reinforce-Ada[19] 认为很多”所谓难 prompt”没法学，其实是 undersampling 造成的统计假象，而不是模型真没潜力。于是它不再用固定小组、统一采样被动等待 mixed outcomes，而是<strong>根据 prompt 难度动态增加推理预算</strong>，主动去找出那些本来会被 uniform GRPO 漏掉的信号。</p>

<p>这个话题还有不少未解问题：</p>

<ul>
  <li>现有 allocator 主要依赖 pass rate、variance proxy 或近期 rollout 统计，但这些量不等于<strong>长期训练价值</strong>——一个 prompt 今天方差大，不代表明天最值得更多预算；</li>
  <li>现有方法仍把单条 prompt 作为分配单位，但 agentic RL 的训练难度更多取决于<strong>交互结构和执行状态</strong>（scaffold、工具链、历史记忆、任务阶段），而不只是 prompt 文本；</li>
  <li>大多数分配器优化的是局部训练效率，还没有把预算分配、reward 结构、hinting、off-policy freshness、长时程 credit assignment <strong>联合起来</strong>。</li>
</ul>

<p>下一步真正值得做的是：把 semantic difficulty、uncertainty、verifier sharpness、历史 learning gain、scaffold transfer 价值、甚至 hinting 后的 gate-open probability，<strong>一起纳入 allocation policy</strong>；把 prompt-level allocation 推广到 trajectory segment、tool-call branch、memory operation 这类更细的 agent 单位。到那时，算力分配才会真正从”更高效的训练技巧”变成 <strong>agentic RL 的核心算法层</strong>。</p>

<h3 id="4-目标函数与策略优化不先问用哪种-rl先问现在到底坏在哪">4. 目标函数与策略优化：不先问”用哪种 RL”，先问”现在到底坏在哪”</h3>

<p>这一部分重点不是 PPO、GRPO、REINFORCE 的技术细节，而是<strong>Agentic RL 的优化器究竟在控制什么</strong>。更本质地说，它在回答两个问题：</p>

<ol>
  <li>高回报轨迹要以多大力度被推回当前策略？</li>
  <li>rollout 分布、learner 更新分布、deployment 执行动作之间允许多大偏移？</li>
</ol>

<p>这里有一条<strong>常被忽略的基本事实</strong>：PPO 那一整套 value network machinery 未必是必要的。ReMax[5] 提醒我们，在文本生成这种快仿真、近似确定性转移、轨迹级奖励的设定下，REINFORCE 路线也可以既简单又稳定。Kimi K1.5 则把长 CoT RL 明确写成 <strong>relative-entropy regularized 的 online mirror descent</strong> 问题。</p>

<p>到了 K2.5、MiniMax-M1 和 GLM-5，问题进一步从”如何估 advantage”转成”如何控制长轨迹、异步 rollout、训练 / 推理 mismatch 下的 off-policy drift”，于是出现了这些看起来很细但实际上很关键的设计：</p>

<ul>
  <li><strong>K2.5 的 token-level clipping</strong>：处理 train-inference framework 差异放大的 off-policy divergence；</li>
  <li><strong>M1 的 CISPO</strong>：裁 importance weights 而不是裁 token updates，在保留更多 token 级梯度的同时控制比值爆炸；</li>
  <li><strong>GLM-5 的 TITO + 双边重要性采样</strong>：确保被优化的动作尽可能还是当时真正被采样的动作。</li>
</ul>

<p><strong>未来真正有价值的优化研究，不是继续修改 PPO 或 GRPO 的公式</strong>，而是先诊断：训练当前究竟受限于哪一类瓶颈——</p>

<ul>
  <li>梯度噪声过大？</li>
  <li>策略漂移过快？</li>
  <li>训练目标与真实任务不匹配？</li>
</ul>

<p>只有先定位清楚，才能决定是改进优势估计、采样方式、更新约束，还是训练调度策略。</p>

<h3 id="5-rollout-采样异步并行与调度调度策略本身就是算法的一部分">5. Rollout 采样、异步并行与调度：调度策略本身就是算法的一部分</h3>

<p>在真实 agent 场景，理想化的同步 on-policy RL 很难被满足：不同 rollout 完成时间差异极大，短的几秒，长的可能几十分钟甚至更久。<strong>坚持严格同步会被 straggler 拖死；完全贪心异步又把训练拖入过重的 off-policy 偏移</strong>。</p>

<p>各家给出的折中方案非常有代表性：</p>

<ul>
  <li><strong>Kimi K1.5 的 partial rollout</strong>：长轨迹切段，未完成部分进 replay buffer，下一轮继续，只有当前段要求 on-policy；</li>
  <li><strong>K2.5</strong>：每个 agent task 都当作独立异步 coroutine，通过专门的 Rollout Manager 支持高并发；</li>
  <li><strong>MiniMax 的 Windowed FIFO</strong>：在”严格 FIFO（稳但慢）”和”完全异步（快但漂移大）”之间做折中——不要求全局严格排队，只在有限窗口内保持大致顺序，让窗口里的已完成任务可以灵活先训练；</li>
  <li><strong>GLM-5</strong>：直接把采样和训练分开，一边持续并行生成轨迹，另一边独立消费数据，再用 TITO + 双边重要性采样 + 陈旧样本过滤来控制异步训练中不可避免的 off-policy 偏移。</li>
</ul>

<p><strong>很多人把 queueing、resume、tail-latency、staleness 当成工程问题，但在 agentic RL 里，调度实际上会改写训练分布</strong>。K1.5 的 partial rollout 意味着一条长轨迹由新旧段拼接而成；MiniMax 的 Windowed FIFO 直接控制了”允许新鲜样本先于更早提交的样本进入训练”的程度；GLM-5 的异步 Agent RL 更是明确承认”不现实去追踪所有历史行为策略，必须在可接受的偏差内做近似校正”。</p>

<p>Agentic RL 的核心不是”如何保持纯 on-policy”，而是<strong>如何在不可避免的异步与陈旧性下，让偏移保持在仍然有学习价值的范围内</strong>。这就是为什么 rollout system 不是承载算法的底座——<strong>它本身就是算法的一部分</strong>。</p>

<h3 id="6-奖励验证器与效率约束reward-定义的不只是答对而是怎样工作才算好">6. 奖励、验证器与效率约束：Reward 定义的不只是”答对”，而是”怎样工作才算好”</h3>

<p>很多关于 agentic RL 的讨论会说”verifier 就够了”。这对真实 Agent 任务其实不成立：agent 的成功不只体现在 final correctness 上，还体现在动作是否合理、工具调用是否合适、是否浪费上下文、是否无意义过度思考、是否拖慢总完成时间、以及输出是否符合更高层的质量和交互要求。</p>

<p>几家的具体做法非常有参考价值：</p>

<ul>
  <li><strong>K2.5</strong>：可验证任务用 rule-based outcome reward，token 成本用 budget-control reward，开放任务用多 rubric GRM，并通过 Toggle 在”尽量做对”和”尽量省 token”之间交替优化；</li>
  <li><strong>MiniMax-M1</strong>：verifiable 与 unverifiable 任务分开处理，用 GenRM 处理不能靠规则验证的任务，并<strong>特别讨论了长 CoT 下 GenRM 的 length bias</strong>——奖励模型偏好更长但未必更好的回答，会直接诱发 reward hacking；</li>
  <li><strong>GLM-5</strong>：把 rule-based reward、ORM、GRM 组合成 hybrid reward system，并明确写出三者权衡——规则奖励精确但窄，ORM 低方差但容易被 exploit，GRM 更灵活但方差更高；</li>
  <li><strong>Forge</strong>：进一步把中间过程质量和<strong>任务完成时间</strong>都纳入 agent RL——真实用户需要的不是”最终做对但过程低效、等待很久”的系统，而是”既能做对、又能较快完成”的 agent。</li>
</ul>

<p><strong>对 reward 正确的理解是”工作方式的规范化”，而不只是”答案质量的评分器”</strong>：</p>

<ul>
  <li>K2.5 用多个 GRM rubric，是因为单一偏好信号太容易被过拟合；</li>
  <li>M1 专门处理 GenRM 长度偏置，是因为 reward model 一旦系统性偏向 verbose response，整个 RL 就会被带偏；</li>
  <li>Forge 引入完成时间相关奖励，是因为真实部署中 agent 的效用不只由正确率决定，还取决于实际耗时。</li>
</ul>

<p>Reward design 的关键<strong>不是给模型更多分数</strong>，而是把 correctness、quality、efficiency、robustness 拆开，再决定哪些可以硬验证、哪些要用模型判断、哪些必须通过对抗测试和 OOD transfer 来防止被投机。</p>

<h3 id="7-记忆层级与并行-agent被训练的对象已经不只是-token-policy而是-operating-policy">7. 记忆、层级与并行 Agent：被训练的对象已经不只是 Token Policy，而是 Operating Policy</h3>

<p>很多人一谈 long-context agent 就想”把 context window 做大一点”。但<strong>长上下文不等于记忆，更不等于好的 agent</strong>。核心问题是：当交互历史越来越长、工具观察越来越多时，模型如何决定什么该保留、什么该丢弃、什么该压缩、什么时候拆任务、什么时候并行多个子 agent？</p>

<ul>
  <li><strong>MiniMax Forge 的 Context Rot</strong>：即使没有触到绝对 context window 上限，长轮次交互中累积的中间推理和冗余 observation 也会造成 attention dilution，让模型失焦。于是 Forge 直接把 <strong>Context Management 纳入 RL 交互回路</strong>，把它当作一种显式 action，让 context transition 本身成为环境状态转移的一部分；</li>
  <li><strong>GLM-5</strong> 在搜索 agent 上也观察到极长上下文会明显伤害性能，因此使用 <strong>keep-recent-k 与 discard-all 的层级式 context management</strong>；</li>
  <li><strong>K2.5 的 Agent Swarm 与 PARL</strong>：当单 agent 顺序执行的延迟变得不可接受时，让 orchestrator 学会<strong>动态任务分解、子 agent 创建和并行调度</strong>。训练时只更新 orchestrator、冻结 sub-agent，以规避最难的 credit ambiguity 与训练不稳定。</li>
</ul>

<p><strong>被优化的对象已经从”token 级生成策略”扩展成”操作系统级策略”</strong>——模型不再只决定下一个 token，而是在决定：</p>

<ul>
  <li>算力怎么花</li>
  <li>上下文怎么管</li>
  <li>任务怎么拆</li>
  <li>子 agent 怎么协作</li>
</ul>

<p>K2.5 的一个关键 insight：<strong>真正的并行 agent 不是把同一个模型复制几份并发运行</strong>，而是让 orchestrator 学会”什么时候值得并行、如何分配子任务、如何在最终汇总时保持全局一致性”。Forge 则强调：记忆管理如果只在 inference 端手工加规则、训练时没见过这种状态转移，最终会形成严重的 inference-training mismatch。</p>

<p>未来 agentic RL 的 frontier，未必是让模型”再想更久”，而是<strong>把 memory editing、hierarchical decomposition 和 agent orchestration 一起纳入训练目标</strong>。</p>

<h3 id="8-infra-基础设施它不是承载算法的底座而是在塑造训练分布">8. Infra 基础设施：它不是承载算法的底座，而是在塑造训练分布</h3>

<p>如果说 RLHF 是在一个相对规整的 prompt → completion → reward → update 闭环上做优化，那么 Agentic RL 面对的是<strong>长短极不均匀、工具调用密集、环境反馈异步、动作语义复杂</strong>的真实交互轨迹。</p>

<p>在这种设定下，基础设施直接决定：</p>

<ul>
  <li>rollout 以什么顺序完成</li>
  <li>哪些样本因过时被丢弃</li>
  <li>哪些前缀能够复用</li>
  <li>训练端和推理端看到的是否还是同一个动作空间</li>
</ul>

<p>这里有三层 infra：</p>

<p><strong>① 塑造训练分布的 rollout / learner 基础设施</strong>。由于任务完成时间可能从秒级跨到小时级，同步 on-policy 几乎不可能，系统必须处理 actor–learner 解耦、队列调度、buffer freshness、checkpoint staleness、partial rollout reuse、stale sample filtering。MiniMax Forge 把 strict FIFO / greedy async / Windowed FIFO 的权衡直接写成”吞吐与分布稳定之间的核心矛盾”；GLM-5 通过异步 generation-training 解耦 + TITO + double-sided importance sampling 控制偏移；K1.5 的 partial rollout reuse 说明<strong>长轨迹能否被复用，本身就是训练 recipe 的一部分</strong>。这一层 infra 直接塑造了”模型真正看到的训练分布”。</p>

<p><strong>② 提升吞吐与成本效率的规模化训练 / 推理 infra</strong>。包括训练 / 推理解耦、数据池缓存、KV / prefix 复用、动态 batching、各种并行化和异构资源调度策略。它们解决的核心问题不是”单点算法是否成立”，而是”这些方法能否在现实成本下真正跑到足够规模”。对 agent workload 来说，模型生成、环境执行、工具调用、verifier 计算、日志存储的资源瓶颈完全不同，基础设施<strong>必须是解耦和分层的</strong>，不能继续沿用单一、同步、同构的训练范式。</p>

<p><strong>③ 保证数值一致性和训练—推理一致性的 serving infra</strong>。最容易被低估但其实最关键：Agentic RL 优化的不是抽象文本，而是<strong>具有明确执行语义的动作序列</strong>——训练时、采样时、部署时对动作的表示或接口稍有错位，模型学到的策略就可能在上线时部分失效。GLM-5 的 TITO 之所以重要，不只是为了省一次 re-tokenization，而是为了精确保持 sampled action 与 optimized action 的对应；MiniMax Forge 的 gateway 与 middleware 设计本质上也在做 action interface standardization。因此，tokenizer / engine 对齐、tool schema 标准化、trajectory serialization、metadata logging、train-serving alignment——都不再只是工程细节，而是在决定<strong>训练时被优化的动作，是否真的是部署时会执行的那个动作</strong>。</p>

<h3 id="9-评测与可观测性测不出来的不变量就守不住">9. 评测与可观测性：测不出来的不变量，就守不住</h3>

<p>前面八节讲了”在哪里守不变量”，但有一个被大多数文章忽视的基础问题：<strong>如果你连三个不变量是否正在被守住都测不出来，就根本无从管理它们</strong>。</p>

<p>Agentic RL 的 evaluation 不能只看 Pass@1 或 final reward，至少需要三类互补的观测维度：</p>

<p><strong>① 探索健康度（对应第一不变量）</strong>：</p>

<ul>
  <li>Pass@k、Potential@k、解法簇数量（semantic cluster count）</li>
  <li>行为路径的 scaffold 迁移率（同一能力在不同 scaffold 下的成功率）</li>
  <li>长期 entropy trajectory 与 action-level 多样性（而不仅是 token-level）</li>
</ul>

<p><strong>② 学习信号健康度（对应第二不变量）</strong>：</p>

<ul>
  <li><strong>non-zero advantage ratio</strong>：一个 batch 内多少 group 产生了非零梯度</li>
  <li><strong>gate-open probability</strong>：group-based 方法中 advantage 有效的样本比例</li>
  <li><strong>组内 reward spread</strong> 和 <strong>gradient SNR</strong></li>
  <li><strong>单位训练时间的有效样本率</strong>（effective tokens per GPU-hour）</li>
</ul>

<p>这些量往往比 loss 曲线更能解释”为什么训练看起来还在跑，但能力没长”。</p>

<p><strong>③ 分布一致性（对应第三不变量）</strong>：</p>

<ul>
  <li><strong>training–serving KL</strong>：相同 prompt 下训练 checkpoint 与部署 checkpoint 的输出分布差异</li>
  <li><strong>rollout staleness 分布</strong>：样本被生成时的策略与被学习时的策略相隔多少步</li>
  <li><strong>tokenizer / tool schema mismatch 率</strong>：训练端与部署端接口一致性的硬指标</li>
  <li><strong>长轨迹误差累积曲线</strong>：模型表现随交互步数退化的速度</li>
</ul>

<p>在更高层次上，还需要一套<strong>对抗性评测</strong>：verifier-hacking 检测、reward-model OOD 探针、scaffold 替换测试、工具噪声注入测试——这些不是”锦上添花的 benchmark”，而是<strong>第一不变量和第二不变量是否被守住的直接证据</strong>。</p>

<p>SWE-Universe 把 hacking detection 自动化进环境，本质上就是在承认：<strong>评测已经不是 pipeline 的末端，而是训练系统的一部分</strong>。没有这层观测，所谓”调参”就只是在黑箱里做随机扰动。</p>

<hr />

<h2 id="四结语agentic-rl-的真正竞争不在单点算法">四、结语：Agentic RL 的真正竞争，不在单点算法</h2>

<p>回到开头那句话——<strong>Agentic RL 的核心问题，已经从”怎么更新参数”扩展为”怎么在真实 Agent 环境里持续制造可用的学习信号”</strong>。</p>

<p>把三条技术路线放在一起看，信号非常清楚：</p>

<ul>
  <li><strong>Kimi 路线</strong>告诉我们：(1) 长上下文本身是一条 RL scaling axis；(2) 复杂的 value function / MCTS / process RM 不是唯一道路，简洁但分布一致的 policy optimization 也能跑出很强的长链能力；(3) 当 agent 工作流变复杂后，奖励模型、token-level clipping、token efficiency 控制和 learned parallel orchestration 会越来越重要。K1.5 → K2.5 的演进，本质上是从”把长 reasoning RL 跑通”走向”把多步 agentic / multimodal RL 规模化”；</li>
  <li><strong>MiniMax 路线</strong>说明：长时程 agent RL 一进到真实环境，首要问题很快就从”模型能不能推理”转向”<strong>系统能不能稳定地持续学习</strong>“。M1 的 CISPO 的价值在于修复长轨迹 RL 的 off-policy 和梯度裁剪副作用；Forge 进一步证明，异步调度、上下文管理、完成时间奖励、跨任务联合训练、前缀树合并这类”看起来很工程”的东西，<strong>实际上决定了你最终能否在大规模真实环境里把 RL 跑起来</strong>；</li>
  <li><strong>GLM 路线</strong>强调：后训练不应该一股脑混在一起，而要按能力类型分阶段组织，并借助蒸馏机制保护已有能力。Reasoning RL → Agentic RL → General RL 的顺序<strong>不只是训练日程安排，而是一种能力编排方式</strong>。GLM-5 对异步 RL 基础设施、TITO、double-sided importance sampling 的强调，也再次说明：<strong>训练系统与策略优化之间已经没有清晰边界</strong>。</li>
</ul>

<p>综合这些路线，一个清晰的结论是：</p>

<blockquote>
  <p>Agentic RL 不只是”更大模型 × 更多数据 × 更多 token”，而是：</p>

  <ul>
    <li>更丰富的<strong>环境覆盖</strong></li>
    <li>更高密度的<strong>有效学习信号</strong></li>
    <li>更一致的 <strong>rollout / update / serving 分布</strong></li>
    <li>更高的<strong>单位时间有效训练效率</strong></li>
    <li>以及能让你<strong>确认前四者正在发生</strong>的评测与可观测性</li>
  </ul>
</blockquote>

<p>在完善高效的 infra 支持下，谁在这五个维度上同时做得更好，谁就更可能真正把 agent 训出来。</p>

<hr />

<h2 id="参考文献">参考文献</h2>

<p>[1] Kimi Team. <em>Kimi k1.5: Scaling Reinforcement Learning with LLMs</em>. arXiv:2501.12599, 2025. <a href="https://arxiv.org/abs/2501.12599">https://arxiv.org/abs/2501.12599</a></p>

<p>[2] Kimi Team. <em>Kimi K2: Open Agentic Intelligence</em>. arXiv:2507.20534, 2025. <a href="https://arxiv.org/abs/2507.20534">https://arxiv.org/abs/2507.20534</a></p>

<p>[3] Kimi Team. <em>Kimi K2.5: Visual Agentic Intelligence</em>. arXiv:2602.02276, 2026. <a href="https://arxiv.org/abs/2602.02276">https://arxiv.org/abs/2602.02276</a></p>

<p>[4] Ziniu Li et al. <em>Preserving Diversity in Supervised Fine-Tuning of Large Language Models</em>. arXiv:2408.16673, 2024. <a href="https://arxiv.org/abs/2408.16673">https://arxiv.org/abs/2408.16673</a></p>

<p>[5] Ziniu Li et al. <em>ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models</em>. arXiv:2310.10505, 2023. <a href="https://arxiv.org/abs/2310.10505">https://arxiv.org/abs/2310.10505</a></p>

<p>[6] Ziniu Li et al. <em>Knapsack RL: Unlocking Exploration of LLMs via Optimizing Budget Allocation</em>. arXiv:2509.25849, 2025. <a href="https://arxiv.org/abs/2509.25849">https://arxiv.org/abs/2509.25849</a></p>

<p>[7] Hanze Dong. <em>Curate the Learning Signal for Reinforcement Learning: Variance Minimization, Adaptive Sampling, and Self-Hinting</em>. Blog post, 2026. <a href="https://hendrydong.github.io/blogs/pages/rl-ada.html">https://hendrydong.github.io/blogs/pages/rl-ada.html</a></p>

<p>[8] GLM-5 Team. <em>GLM-5: from Vibe Coding to Agentic Engineering</em>. arXiv:2602.15763, 2026. <a href="https://arxiv.org/abs/2602.15763">https://arxiv.org/abs/2602.15763</a></p>

<p>[9] MiniMax. <em>Forge: Scalable Agent RL Framework and Algorithm</em>. MiniMax News, 2026. <a href="https://www.minimax.io/news/forge-scalable-agent-rl-framework-and-algorithm">https://www.minimax.io/news/forge-scalable-agent-rl-framework-and-algorithm</a></p>

<p>[10] Mouxiang Chen et al. <em>SWE-Universe: Scale Real-World Verifiable Environments to Millions</em>. arXiv:2602.02361, 2026. <a href="https://arxiv.org/abs/2602.02361">https://arxiv.org/abs/2602.02361</a></p>

<p>[11] MiniMax. <em>MiniMax-M1: Scaling Test-Time Compute Efficiently with Lightning Attention</em>. arXiv:2506.13585, 2025. <a href="https://arxiv.org/abs/2506.13585">https://arxiv.org/abs/2506.13585</a></p>

<p>[12] DeepSeek-AI. <em>DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning</em>. arXiv:2501.12948, 2025. <a href="https://arxiv.org/abs/2501.12948">https://arxiv.org/abs/2501.12948</a></p>

<p>[13] Xinran Li et al. <em>Getting Your LLMs Ready for Reinforcement Learning with Lightweight SFT</em>. OpenReview / ICLR 2026. <a href="https://openreview.net/forum?id=yezWGJmODg">https://openreview.net/forum?id=yezWGJmODg</a></p>

<p>[14] Qiyuan Yu et al. <em>DAPO: An Open-Source LLM Reinforcement Learning System at Scale</em>. arXiv:2503.14476, 2025. <a href="https://arxiv.org/abs/2503.14476">https://arxiv.org/abs/2503.14476</a></p>

<p>[15] Jian Yao et al. <em>Diversity-Aware Policy Optimization for Large Language Model Reasoning</em>. arXiv:2505.23433, 2025. <a href="https://arxiv.org/abs/2505.23433">https://arxiv.org/abs/2505.23433</a></p>

<p>[16] Xingyu Dang et al. <em>Assessing Diversity Collapse in Reasoning</em>. OpenReview, 2025. <a href="https://openreview.net/forum?id=AMiKsHLjQh">https://openreview.net/forum?id=AMiKsHLjQh</a></p>

<p>[17] Jiarui Yao et al. <em>Optimizing Chain-of-Thought Reasoners via Gradient Variance Minimization in Rejection Sampling and RL</em>. arXiv:2505.02391, 2025. <a href="https://arxiv.org/abs/2505.02391">https://arxiv.org/abs/2505.02391</a></p>

<p>[18] Hieu Trung Nguyen et al. <em>Adaptive Rollout Allocation for Online Reinforcement Learning with Verifiable Rewards</em>. arXiv:2602.01601, 2026. <a href="https://arxiv.org/abs/2602.01601">https://arxiv.org/abs/2602.01601</a></p>

<p>[19] Wei Xiong et al. <em>Reinforce-Ada: An Adaptive Sampling Framework for Reinforce-Style LLM Training</em>. arXiv:2510.04996, 2025. <a href="https://arxiv.org/abs/2510.04996">https://arxiv.org/abs/2510.04996</a></p>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="reinforcement learning" /><category term="agentic" /><category term="llm" /><category term="post-training" /><category term="infrastructure" /><summary type="html"><![CDATA[过去一年，各家大模型公司公开的技术报告透出的最重要信号，不是又出现了一个更好的 PPO/GRPO 变体，而是真正有效的 Agentic RL 已经从”单轮文本优化”转向了”在长上下文、工具调用、部分可观测、异步执行环境中的系统性策略学习”。 Kimi K1.5[1] 把长上下文 RL、partial rollout 重用和 mirror-descent 风格的 policy optimization 拉到了台前；Kimi K2[2]/K2.5[3] 又把 agentic 数据合成、多模态 RL、token-level clipping、GRM rubric、Toggle、PARL / Agent Swarm 这些关键部件公开；MiniMax 把另一个事实讲得更彻底：当 rollout 时长从秒级扩到小时级，训练瓶颈就不再是 loss design，而是吞吐、稳定性与 agent 灵活性之间的三难权衡；GLM 则强调分阶段 RL：Reasoning RL、Agentic RL、General RL 不是混在一起一次训完，而是通过顺序化 pipeline 逐步推进，并借助异步 RL 基础设施与跨阶段蒸馏来兼顾长时程 agent 学习与能力保持。 Agentic RL 的核心问题，已经从”怎么更新参数”扩展为”怎么在真实 Agent 环境里持续制造可用的学习信号，并用在线交互的轨迹数据驱动优化“。]]></summary></entry><entry><title type="html">图解 Wan2.1 I2V：从一张图到一段视频，模型到底发生了什么</title><link href="https://liyongzhi.xyz/posts/2026/04/wan21-i2v-explained/" rel="alternate" type="text/html" title="图解 Wan2.1 I2V：从一张图到一段视频，模型到底发生了什么" /><published>2026-04-24T00:00:00+08:00</published><updated>2026-04-24T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-wan21-i2v-explained</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/wan21-i2v-explained/"><![CDATA[<p>最近视频生成模型卷得很快，<code class="language-plaintext highlighter-rouge">Wan2.1</code> 是阿里 Wan 团队开源的那一套。它最常用的场景之一就是 <strong>I2V（Image-to-Video）</strong>：给一张参考图加一句文字 prompt，模型给你生成一段几秒的视频，首帧基本还是那张图，后续的镜头就按你写的文字去演。</p>

<p>这篇文章想做的事情是：</p>

<blockquote>
  <p>把 Wan2.1 I2V 里<strong>每一步数据发生了什么</strong>讲清楚，让从没接触过视频生成的人也能看懂。</p>
</blockquote>

<p>我们会从最外层的”图像 + 文字 → 视频”讲起，一路剥开壳子：<br />
VAE 到底在压缩什么、CLIP 和 T5 各自管什么、DiT 内部是怎么把图像信息和文字信息混进去的、采样循环为什么要跑那么多步、以及为什么首帧会这么”像”你给的那张图。</p>

<p><br />
<img align="center" width="1000" src="https://liyongzhi.xyz/images/posts/wan21-i2v-overview.svg" alt="Wan2.1 I2V overall architecture" />
<br /></p>

<p>这张图是全文的总地图。下面的每一节都是在放大它的某一块。</p>

<hr />

<h2 id="1-先做一次外行翻译i2v-到底在做什么">1. 先做一次”外行翻译”：I2V 到底在做什么</h2>

<p>如果用一句日常语言来描述 I2V，其实是：</p>

<blockquote>
  <p>我们有一张图（<code class="language-plaintext highlighter-rouge">3 × H × W</code>，RGB 像素），想把它”续写”成一段视频（<code class="language-plaintext highlighter-rouge">3 × F × H × W</code>，F 帧），而且这段视频的内容要符合文字 prompt。</p>
</blockquote>

<p>朴素想法是直接训练一个”图 + 文字 → 视频”的网络。问题有二：</p>

<ol>
  <li>视频的体积太大。即便是 480p × 24fps × 4 秒，也已经是 1.1 亿像素级别，直接建模太贵。</li>
  <li>我们希望生成过程是<strong>可控的</strong>——能调 guidance，能控制风格，能多步修正——而不是一次性跑完一个巨大网络就结束。</li>
</ol>

<p>Diffusion 模型的套路恰好能解决这两件事：</p>

<ul>
  <li><strong>压缩</strong>：用 VAE 把视频压到一个小很多的 latent 空间，之后所有运算都在 latent 上做。</li>
  <li><strong>迭代</strong>：扩散模型天然是多步的，每一步都在”把更接近噪声的视频”往”更清晰的视频”方向推一点。</li>
</ul>

<p>所以 Wan2.1 I2V 的骨架分成两大块：</p>

<table>
  <thead>
    <tr>
      <th>模块</th>
      <th>角色</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>Wan-VAE</strong></td>
      <td>像素 ⇄ latent 的翻译员</td>
    </tr>
    <tr>
      <td><strong>DiT</strong></td>
      <td>在 latent 空间里”去噪”的大脑</td>
    </tr>
  </tbody>
</table>

<p>外加两个<strong>条件编码器</strong>：</p>

<table>
  <thead>
    <tr>
      <th>模块</th>
      <th>角色</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>CLIP ViT-H/14</strong></td>
      <td>把参考图变成”这张图看起来讲了什么”的高层语义向量</td>
    </tr>
    <tr>
      <td><strong>umT5</strong></td>
      <td>把文字 prompt 编码成一串 token embedding</td>
    </tr>
  </tbody>
</table>

<p>接下来我们分别看每一块。</p>

<hr />

<h2 id="2-wan-vae把视频压缩-256-倍后再还原">2. Wan-VAE：把视频压缩 256 倍后再还原</h2>

<p><code class="language-plaintext highlighter-rouge">Wan-VAE</code> 是一个 <strong>3D Causal VAE</strong>。它做的事很朴素：</p>

<ul>
  <li>输入：<code class="language-plaintext highlighter-rouge">[3, F, H, W]</code> 的视频（或单张图当作 <code class="language-plaintext highlighter-rouge">F=1</code>）</li>
  <li>输出：<code class="language-plaintext highlighter-rouge">[16, F/4, H/8, W/8]</code> 的 latent</li>
</ul>

<p>换句话说：</p>

<ul>
  <li>空间下采样 <code class="language-plaintext highlighter-rouge">8×8</code> 倍</li>
  <li>时间下采样 <code class="language-plaintext highlighter-rouge">4</code> 倍</li>
  <li>通道数从 <code class="language-plaintext highlighter-rouge">3</code> 变成 <code class="language-plaintext highlighter-rouge">16</code>（表达能力变强）</li>
</ul>

<p>总体积约压缩 <strong>256 倍</strong>（<code class="language-plaintext highlighter-rouge">8·8·4 / (16/3) ≈ 24×24 / ...</code>，算下来大约 48× 的”信息体积”，但浮点数要少 200+ 倍）。</p>

<blockquote>
  <p><strong>为什么叫 Causal？</strong> 指的是它的时间卷积只看”过去”不看”未来”，这样可以支持变长视频、流式推理，和后续滚动生成新帧。</p>
</blockquote>

<p>一个关键点是 <strong>I2V 里 VAE 会被用两次</strong>：</p>

<ol>
  <li><strong>编码参考图</strong>：把那张图当成一个 <code class="language-plaintext highlighter-rouge">F=1</code> 的视频编码，得到它的 latent。</li>
  <li><strong>解码最终视频</strong>：DiT 输出 latent，扔给 VAE 解码回像素视频。</li>
</ol>

<p>其中第一次编码的结果被塞进 DiT 作为”低层像素/结构”条件——这是后面讲 I2V 双路条件时的关键一环。</p>

<hr />

<h2 id="3-两路文字图像条件编码clip-和-t5-各自做什么">3. 两路文字/图像条件编码：CLIP 和 T5 各自做什么</h2>

<p>这两个模型很多人容易搞混，但它们在 Wan2.1 里分工很清晰。</p>

<h3 id="31-umt5把文字变成-512--4096-的-token-序列">3.1 umT5：把文字变成 512 × 4096 的 token 序列</h3>

<p><code class="language-plaintext highlighter-rouge">umT5</code> 是 T5 的多语言版。输入是你的 prompt，输出是：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[seq_len, 4096]   # 每个 token 一个 4096 维向量
</code></pre></div></div>

<p>Wan2.1 统一把这个序列 padding / truncate 到 <strong>512 个 token</strong>，所以文本总是 <code class="language-plaintext highlighter-rouge">[512, 4096]</code>。</p>

<blockquote>
  <p>T5 是一个纯文本的大模型，它的向量很”语言化”，擅长表达语义、句法关系。</p>
</blockquote>

<h3 id="32-clip-vit-h14把图像变成-257--1280-的-token-序列">3.2 CLIP ViT-H/14：把图像变成 257 × 1280 的 token 序列</h3>

<p><code class="language-plaintext highlighter-rouge">CLIP</code> 是一个<strong>跨模态</strong>模型（图像 + 文本对齐训练的），这里我们只用它的<strong>图像编码器</strong>（ViT-H/14）。</p>

<p>它吃一张 <code class="language-plaintext highlighter-rouge">224 × 224</code> 的图，输出：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[257, 1280]
</code></pre></div></div>

<p><strong>257 从哪来？</strong> 这是一个很常见的数字：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">ViT-H/14</code> 把 224×224 切成 <code class="language-plaintext highlighter-rouge">14×14</code> 的 patch</li>
  <li><code class="language-plaintext highlighter-rouge">224 / 14 = 16</code>，所以一张图变成 <code class="language-plaintext highlighter-rouge">16 × 16 = 256</code> 个 patch token</li>
  <li>再加一个 <code class="language-plaintext highlighter-rouge">CLS</code> token，总共 <strong>257</strong> 个</li>
</ul>

<p>每个 token 的通道数是 <code class="language-plaintext highlighter-rouge">1280</code>（ViT-H 的隐藏维度）。</p>

<blockquote>
  <p>CLIP 给出的是<strong>图像的高层语义</strong>：它知道这张图里是”一只猫”、”傍晚的海边”、”油画风格”之类的语义抽象，但几乎不保留像素级的精细结构。</p>
</blockquote>

<h3 id="33-clip-vs-t5为什么两个都要">3.3 CLIP vs T5：为什么两个都要？</h3>

<p>这是 I2V 非常关键的一点。两者的”关注点”不一样：</p>

<table>
  <thead>
    <tr>
      <th> </th>
      <th>擅长</th>
      <th>不擅长</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>T5</strong></td>
      <td>文字描述的动作、意图、场景</td>
      <td>图像具体长什么样</td>
    </tr>
    <tr>
      <td><strong>CLIP</strong></td>
      <td>参考图的整体风格、主体</td>
      <td>精确的像素/空间结构</td>
    </tr>
  </tbody>
</table>

<p>所以两者是<strong>互补</strong>的——都给 DiT 看一遍，DiT 再自己挑。这也是为什么后面会看到 cross-attention 是”双流”的。</p>

<hr />

<h2 id="4-把参考图的像素也塞进模型条件-latent-y">4. 把”参考图的像素”也塞进模型：条件 latent <code class="language-plaintext highlighter-rouge">y</code></h2>

<p>到这里我们已经有了两条图像通路：CLIP（语义）和 T5（文字）。但对 I2V 来说，仅靠 CLIP 的语义是不够的——生成的第一帧如果不能”长得非常像”输入图，用户立刻会觉得不对。</p>

<p>于是 Wan2.1 加了第三条通路：<strong>把参考图用 VAE 编码后，直接在通道维度拼到噪声 latent 上</strong>。</p>

<h3 id="41-构造-y">4.1 构造 <code class="language-plaintext highlighter-rouge">y</code></h3>

<p>假设目标视频是 <code class="language-plaintext highlighter-rouge">F</code> 帧，latent 形状 <code class="language-plaintext highlighter-rouge">[16, F/4, H/8, W/8]</code>。我们把 <code class="language-plaintext highlighter-rouge">T_latent = F/4</code>。</p>

<p><strong>第 1 步：把参考图放到第 0 帧，其余帧置零。</strong></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">video_clip</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span>
    <span class="n">img_resized</span><span class="p">,</span>                <span class="c1"># [3, 1, H, W]  ← 第 0 帧 = 参考图
</span>    <span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">F</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">)</span>         <span class="c1"># 其余帧为 0
</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>                       <span class="c1"># → [3, F, H, W]
</span></code></pre></div></div>

<p><strong>第 2 步：VAE 编码。</strong></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y_latent</span> <span class="o">=</span> <span class="n">VAE</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">video_clip</span><span class="p">)</span>   <span class="c1"># → [16, T_latent, H/8, W/8]
</span></code></pre></div></div>

<p><strong>第 3 步：构造时间 mask，标记”哪些帧是已知的”。</strong></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">msk</span> <span class="o">=</span> <span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">F</span><span class="p">,</span> <span class="n">H_lat</span><span class="p">,</span> <span class="n">W_lat</span><span class="p">)</span>
<span class="n">msk</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="mi">0</span>                         <span class="c1"># 只有第 0 帧 = 1
# 把 msk[:, 0:1] 沿时间 repeat 4 次，和 msk[:, 1:] 拼接
# 再 reshape 成 [4, T_latent, H_lat, W_lat]
</span></code></pre></div></div>

<p>这里的 <code class="language-plaintext highlighter-rouge">4</code> 是 VAE 的时间 stride——我们需要让 mask 通道数足够”表达”被 VAE 压缩掉的时间细节。</p>

<p><strong>第 4 步：mask 和 VAE latent 通道拼接，得到 <code class="language-plaintext highlighter-rouge">y</code>。</strong></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y</span> <span class="o">=</span> <span class="n">concat</span><span class="p">([</span><span class="n">msk</span><span class="p">,</span> <span class="n">y_latent</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>    <span class="c1"># [4 + 16 = 20, T_latent, H_lat, W_lat]
</span></code></pre></div></div>

<blockquote>
  <p>把 <code class="language-plaintext highlighter-rouge">y</code> 想成一块”透明纸”：第 0 帧那一层写满了”你要照着这张图画”，其它帧那一层是空白，同时还有一层专门标注”哪里非空白”。</p>
</blockquote>

<h3 id="42-y-怎么进-dit">4.2 <code class="language-plaintext highlighter-rouge">y</code> 怎么进 DiT</h3>

<p>DiT 的输入是噪声 latent <code class="language-plaintext highlighter-rouge">x_t: [16, T_latent, H_lat, W_lat]</code>。进网络前做一次通道拼接：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>x = concat(x_t, y) = [16 + 20, T, H, W] = [36, T, H, W]
</code></pre></div></div>

<p>所以 I2V 的 DiT 输入通道是 <strong>36</strong>（T2V 是 16）。也正因为这个差别，I2V checkpoint 的 <code class="language-plaintext highlighter-rouge">patch_embedding</code> 卷积权重和 T2V 不是一回事。</p>

<hr />

<h2 id="5-dit-内部一个时间步里到底跑了什么">5. DiT 内部：一个时间步里到底跑了什么</h2>

<p>接下来进入最核心的部分。我们放一下 DiT 单层的结构图：</p>

<p><br />
<img align="center" width="1100" src="https://liyongzhi.xyz/images/posts/wan21-i2v-block.svg" alt="Wan2.1 DiT block internal" />
<br /></p>

<p>整体看，DiT 是一个典型的 Transformer 栈，但有三个重要定制：</p>

<ol>
  <li><strong>时空 3D RoPE</strong>（self-attention 里的位置编码）</li>
  <li><strong>双流 cross-attention</strong>（image KV + text KV）</li>
  <li><strong>AdaLN-Zero 风格的 timestep 调制</strong></li>
</ol>

<p>下面一条一条讲。</p>

<h3 id="51-patchify把视频-latent-变成-transformer-的-token-序列">5.1 Patchify：把视频 latent 变成 Transformer 的 token 序列</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="bp">self</span><span class="p">.</span><span class="n">patch_embedding</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Conv3d</span><span class="p">(</span><span class="mi">36</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">),</span> <span class="n">stride</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
</code></pre></div></div>

<p>这是一个”<strong>3D patchify</strong>“：用 <code class="language-plaintext highlighter-rouge">Conv3d</code> 把每个 <code class="language-plaintext highlighter-rouge">1×2×2</code> 的时空小块打成一个 token。</p>

<ul>
  <li>时间方向 kernel=1，意味着<strong>时间维度不被合并</strong>（每一个 latent 帧仍然是独立的一层 token）。</li>
  <li>空间方向 kernel=2，把 <code class="language-plaintext highlighter-rouge">H_lat × W_lat</code> 的网格再进一步压 <code class="language-plaintext highlighter-rouge">2×2</code>，得到 <code class="language-plaintext highlighter-rouge">H_lat/2 × W_lat/2</code> 个 token。</li>
</ul>

<p>最终序列长度是：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>S = T_latent × (H_lat/2) × (W_lat/2)
</code></pre></div></div>

<p>每个 token 是 <code class="language-plaintext highlighter-rouge">dim</code> 维向量（1.3B 版本里 <code class="language-plaintext highlighter-rouge">dim=2048</code>）。</p>

<h3 id="52-timestep-embedding让每层都知道现在在第几步">5.2 Timestep embedding：让每层都知道”现在在第几步”</h3>

<p>扩散模型的一个关键差别是每一步的处理方式不一样。T=T_max 时几乎全是噪声，T=0 时已经是完整视频，所以模型在不同 step 应该”轻重不一”。</p>

<p>Wan2.1 的做法是 <strong>AdaLN-Zero</strong>（DiT 论文里的那一套）：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">e</span> <span class="o">=</span> <span class="n">sinusoidal_embedding_1d</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>          <span class="c1"># 标量 t → 256 维向量
</span><span class="n">e</span> <span class="o">=</span> <span class="n">time_embedding</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>                        <span class="c1"># MLP 投到 dim
</span><span class="n">e0</span> <span class="o">=</span> <span class="n">time_projection</span><span class="p">(</span><span class="n">e</span><span class="p">).</span><span class="n">unflatten</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span><span class="n">dim</span><span class="p">))</span>  <span class="c1"># 再投成 [B, 6, dim]
</span></code></pre></div></div>

<p>然后把这 6 份向量分发给每个 block，块内再加上自己可学习的 <code class="language-plaintext highlighter-rouge">modulation</code> 参数，切成 6 组：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(shift1, scale1, gate1,  shift2, scale2, gate2)
</code></pre></div></div>

<ul>
  <li><code class="language-plaintext highlighter-rouge">shift, scale</code> 用在 LayerNorm 之后：<code class="language-plaintext highlighter-rouge">x' = norm(x) · (1 + scale) + shift</code></li>
  <li><code class="language-plaintext highlighter-rouge">gate</code> 用在残差分支：<code class="language-plaintext highlighter-rouge">x = x + gate · f(x')</code></li>
</ul>

<blockquote>
  <p><strong>“Zero” 的含义</strong>：<code class="language-plaintext highlighter-rouge">gate</code> 初始化为 0，使得模型训练开始时每个 block 都是恒等映射——DiT 从一个干净的起点开始学。</p>
</blockquote>

<p>注意：cross-attention <strong>不被 AdaLN 调制</strong>，只有 self-attention 和 FFN 被调制。</p>

<h3 id="53-self-attention3d-全局注意力--分解式-rope">5.3 Self-Attention：3D 全局注意力 + 分解式 RoPE</h3>

<p>这一步做的事很简单：<strong>视频 token 之间互相看</strong>。</p>

<p>代码上是标准的 QKV flash attention，但有两处定制：</p>

<p><strong>① QK 做 RMSNorm</strong>。这是稳定训练用的技巧：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">q</span> <span class="o">=</span> <span class="n">RMSNorm</span><span class="p">(</span><span class="n">Linear_q</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="n">k</span> <span class="o">=</span> <span class="n">RMSNorm</span><span class="p">(</span><span class="n">Linear_k</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="n">v</span> <span class="o">=</span> <span class="n">Linear_v</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>

<p><strong>② 3D 分解式 RoPE 作用在 Q/K 上</strong>（不作用于 V）。</p>

<p>视频 token 有三个坐标：<code class="language-plaintext highlighter-rouge">(frame, height, width)</code>。Wan2.1 把每个 head 的维度 <code class="language-plaintext highlighter-rouge">d</code> 切成三段：</p>

<table>
  <thead>
    <tr>
      <th>段</th>
      <th>通道数</th>
      <th>编码的是</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>时间</td>
      <td><code class="language-plaintext highlighter-rouge">d − 4·(d/6)</code></td>
      <td>帧索引 <code class="language-plaintext highlighter-rouge">f</code></td>
    </tr>
    <tr>
      <td>高</td>
      <td><code class="language-plaintext highlighter-rouge">2·(d/6)</code></td>
      <td>行索引 <code class="language-plaintext highlighter-rouge">h</code></td>
    </tr>
    <tr>
      <td>宽</td>
      <td><code class="language-plaintext highlighter-rouge">2·(d/6)</code></td>
      <td>列索引 <code class="language-plaintext highlighter-rouge">w</code></td>
    </tr>
  </tbody>
</table>

<p>三段分别应用一维 RoPE（复数旋转），然后沿通道拼回一起：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>q_rot = RoPE_T(q[:, :, :dT], f) ⊕ RoPE_H(q[:, :, dT:dT+dH], h) ⊕ RoPE_W(q[:, :, dT+dH:], w)
</code></pre></div></div>

<p><strong>为什么这样设计？</strong></p>

<ul>
  <li>可以支持<strong>任意分辨率和帧数</strong>，因为 RoPE 是外推良好的位置编码。</li>
  <li>时间/空间的频率独立，模型可以各自学合适的”时间尺度”和”空间尺度”。</li>
  <li>相比绝对位置嵌入，训练时可以在一个尺度下训，推理时换尺度不会崩。</li>
</ul>

<p><strong>注意力的范围是”全 3D 全局”</strong>——所有视频 token 互相能看。这就是为什么视频生成模型这么贵：序列长度是 <code class="language-plaintext highlighter-rouge">T × H × W</code>，attention 是 O(S²)。</p>

<h3 id="54-cross-attention双流融合图像--文字">5.4 Cross-Attention：双流融合图像 + 文字</h3>

<p>到了 I2V 最有意思的设计。先回忆一下 context 长什么样：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>context = [ CLIP_257 ∥ T5_512 ]         # shape [769, dim]
         ─── 前 257 ───   ── 后 512 ──
             (图像)           (文本)
</code></pre></div></div>

<p>如果用朴素 cross-attention，你会一次算 <code class="language-plaintext highlighter-rouge">attn(q, K=k_all, V=v_all)</code>，让视频 token 对这 769 个 token 做 softmax。问题是图像和文本的分布差距很大，softmax 会把注意力偏到一侧。</p>

<p>Wan2.1 的做法是<strong>双流独立</strong>：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 共享 Query
</span><span class="n">q</span> <span class="o">=</span> <span class="n">Linear_q</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="c1"># 图像分支（独立的 k_img, v_img）
</span><span class="n">k_img</span> <span class="o">=</span> <span class="n">Linear_k_img</span><span class="p">(</span><span class="n">context</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">257</span><span class="p">])</span>
<span class="n">v_img</span> <span class="o">=</span> <span class="n">Linear_v_img</span><span class="p">(</span><span class="n">context</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">257</span><span class="p">])</span>

<span class="c1"># 文本分支（共享 T2V 的 k, v）
</span><span class="n">k_txt</span> <span class="o">=</span> <span class="n">Linear_k</span><span class="p">(</span><span class="n">context</span><span class="p">[:,</span> <span class="mi">257</span><span class="p">:])</span>
<span class="n">v_txt</span> <span class="o">=</span> <span class="n">Linear_v</span><span class="p">(</span><span class="n">context</span><span class="p">[:,</span> <span class="mi">257</span><span class="p">:])</span>

<span class="n">out_img</span> <span class="o">=</span> <span class="n">flash_attn</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k_img</span><span class="p">,</span> <span class="n">v_img</span><span class="p">)</span>
<span class="n">out_txt</span> <span class="o">=</span> <span class="n">flash_attn</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k_txt</span><span class="p">,</span> <span class="n">v_txt</span><span class="p">)</span>

<span class="n">out</span> <span class="o">=</span> <span class="n">Linear_o</span><span class="p">(</span><span class="n">out_img</span> <span class="o">+</span> <span class="n">out_txt</span><span class="p">)</span>   <span class="c1"># 逐元素相加再过输出投影
</span></code></pre></div></div>

<p>几个关键设计点：</p>

<ol>
  <li><strong>独立的 K/V 投影</strong>：图像用 <code class="language-plaintext highlighter-rouge">k_img, v_img</code>，文本用 <code class="language-plaintext highlighter-rouge">k, v</code>。每一模态在自己的几何空间里算 attention，不会互相挤压 softmax。</li>
  <li><strong>两次独立 attention 再相加</strong>：相当于两种信号<strong>分别</strong>给每个视频 token 打了一次分，再叠加作为新的残差。</li>
  <li><strong>Q 共享</strong>：视频 token 只有一份”问题”，问图和文字同一个问题：”你们谁和我相关？”</li>
  <li><strong>无 RoPE</strong>：cross-attn 中的 K/V 是外部序列，不需要视频的时空位置编码。</li>
</ol>

<blockquote>
  <p>直观理解：<strong>image 分支管”我希望长什么样”，text 分支管”我希望怎么演”，两个加在一起就是视频 token 的条件梯度</strong>。</p>
</blockquote>

<h3 id="55-ffn标准-mlp再来一次-adaln-门控">5.5 FFN：标准 MLP，再来一次 AdaLN 门控</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">y</span> <span class="o">=</span> <span class="n">ffn</span><span class="p">(</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="err">·</span> <span class="p">(</span><span class="mi">1</span><span class="o">+</span><span class="n">scale2</span><span class="p">)</span> <span class="o">+</span> <span class="n">shift2</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">gate2</span> <span class="err">·</span> <span class="n">y</span>
</code></pre></div></div>

<p><code class="language-plaintext highlighter-rouge">ffn</code> 就是常规的 <code class="language-plaintext highlighter-rouge">Linear → GELU → Linear</code>，中间维度是 <code class="language-plaintext highlighter-rouge">4 × dim</code>（比如 dim=2048 时 ffn_dim=8192）。</p>

<p>到这里一个 block 就结束了。把这个 block 叠 32 层（1.3B 版）或 40 层（14B 版），最后过一个 <code class="language-plaintext highlighter-rouge">Head</code>（也带 AdaLN 和 unpatchify），就能把 <code class="language-plaintext highlighter-rouge">[B, S, dim]</code> 变回 <code class="language-plaintext highlighter-rouge">[16, T, H, W]</code>——也就是模型对<strong>当前时间步的”速度场” <code class="language-plaintext highlighter-rouge">v</code></strong> 的预测。</p>

<hr />

<h2 id="6-训练目标为什么叫-flow-matching不再叫预测噪声">6. 训练目标：为什么叫 Flow Matching，不再叫”预测噪声”</h2>

<p>DDPM 早年是让模型预测”这张图里的噪声 <code class="language-plaintext highlighter-rouge">ε</code>“。Wan2.1 用的是 <strong>Flow Matching / Rectified Flow</strong> 的范式——本质上是把扩散过程理解成<strong>一条从噪声到数据的直线路径</strong>，模型学的是这条路径上每一点的”速度”。</p>

<p>具体来说，定义一条插值：</p>

\[x_t = (1 - t) \cdot x_0 + t \cdot \epsilon, \quad t \in [0, 1], \quad \epsilon \sim \mathcal{N}(0, I)\]

<p>那么真值速度就是：</p>

\[v^* = \frac{d x_t}{d t} = \epsilon - x_0\]

<p>训练目标：</p>

\[\mathcal{L}_{\text{FM}} = \mathbb{E}_{x_0, \epsilon, t} \left\| v_\theta(x_t, t, c) - (\epsilon - x_0) \right\|^2\]

<p>其中 <code class="language-plaintext highlighter-rouge">c = {y, CLIP_fea, T5_text}</code> 是所有条件的合集。</p>

<p><strong>Flow Matching 相比预测 ε 有什么好处？</strong></p>

<ul>
  <li>训练 loss 更稳定，对 <code class="language-plaintext highlighter-rouge">t</code> 的依赖更平滑。</li>
  <li>采样时可以用更少的步数。典型配置 <strong>25–50 步</strong>即可出不错结果（早期 DDPM 需要 1000 步）。</li>
  <li>路径”直”这件事意味着模型不容易陷入局部的噪声拟合。</li>
</ul>

<hr />

<h2 id="7-一次完整的推理25-步里到底发生了什么">7. 一次完整的推理：25 步里到底发生了什么</h2>

<p>现在把所有东西串起来。假设你给了一张 <code class="language-plaintext highlighter-rouge">H × W</code> 的图、一句 prompt，让模型生成 F 帧的视频：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>─── 推理前准备（只做 1 次）────────────────────────────────
1. t5_ctx  = umT5(prompt)                      # [512, 4096] → MLP → [512, dim]
2. clip_fea = CLIP.visual(image)               # [257, 1280] → MLPProj → [257, dim]
3. img_lat = VAE.encode([image, zeros, ...])   # [16, T, H_lat, W_lat]
4. msk     = build_mask(first_frame=1)         # [4, T, H_lat, W_lat]
5. y       = concat(msk, img_lat)              # [20, T, H_lat, W_lat]
6. context = concat(clip_fea, t5_ctx)          # [769, dim]
7. x_T     ~ N(0, I)                           # [16, T, H_lat, W_lat]

─── 采样循环（跑 25~50 次）─────────────────────────────
for t in schedule:  # e.g. [1.0, 0.96, ..., 0.0]
    # (可选) CFG: 各跑一次有/无条件
    x_in = concat(x_t, y)          # [36, T, H_lat, W_lat]
    v_cond = DiT(x_in, t, context, clip_fea, y)
    # v_uncond = DiT(x_in, t, empty_context, ...)
    # v = v_uncond + s · (v_cond - v_uncond)
    v = v_cond
    
    x_{t-Δt} = x_t - v · Δt        # flow matching 欧拉步

─── 解码 ──────────────────────────────────────────
video_latent = x_0                  # [16, T, H_lat, W_lat]
video        = VAE.decode(video_latent)  # [3, F, H, W]
</code></pre></div></div>

<p>几个细节：</p>

<ul>
  <li><strong><code class="language-plaintext highlighter-rouge">y</code> 只构造一次</strong>，在整个 25 步里都用同一份。因为参考图是不变的。</li>
  <li><strong>CFG（Classifier-Free Guidance）</strong>：Wan2.1 训练时会随机丢弃条件，所以推理时可以通过 <code class="language-plaintext highlighter-rouge">v = v_u + s·(v_c - v_u)</code> 放大条件信号（典型 <code class="language-plaintext highlighter-rouge">s=5~7.5</code>）。每步需要跑两遍 DiT。</li>
  <li><strong>首帧为什么保真？</strong>：因为第 0 帧的 <code class="language-plaintext highlighter-rouge">mask=1</code> 和 <code class="language-plaintext highlighter-rouge">VAE(img)</code> 一直被塞进输入，DiT 每步都在”被提醒”首帧应该长什么样。随着 t 变小，模型越来越相信这个约束。</li>
</ul>

<hr />

<h2 id="8-几个关键数字一张表带走">8. 几个关键数字一张表带走</h2>

<table>
  <thead>
    <tr>
      <th>参数</th>
      <th>值</th>
      <th>解释</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>VAE 空间 stride</td>
      <td>8</td>
      <td>H/W 方向下采样倍率</td>
    </tr>
    <tr>
      <td>VAE 时间 stride</td>
      <td>4</td>
      <td>F 方向下采样倍率</td>
    </tr>
    <tr>
      <td>VAE latent 通道</td>
      <td>16</td>
      <td>压缩后的通道数</td>
    </tr>
    <tr>
      <td>I2V <code class="language-plaintext highlighter-rouge">y</code> 通道</td>
      <td><strong>20</strong></td>
      <td>4 (mask) + 16 (VAE latent)</td>
    </tr>
    <tr>
      <td>DiT 输入通道</td>
      <td><strong>36</strong></td>
      <td>16 (noise) + 20 (y)</td>
    </tr>
    <tr>
      <td>Patch size</td>
      <td>(1, 2, 2)</td>
      <td>时间不并、空间 2×2</td>
    </tr>
    <tr>
      <td>文本 token 数</td>
      <td>512</td>
      <td>umT5 输出 padded</td>
    </tr>
    <tr>
      <td>CLIP token 数</td>
      <td><strong>257</strong></td>
      <td>1 CLS + 16×16 patches</td>
    </tr>
    <tr>
      <td>CLIP 维度</td>
      <td>1280</td>
      <td>ViT-H 的 hidden</td>
    </tr>
    <tr>
      <td>DiT hidden</td>
      <td>2048 (1.3B) / 更大 (14B)</td>
      <td> </td>
    </tr>
    <tr>
      <td>DiT 层数</td>
      <td>32 / 40</td>
      <td> </td>
    </tr>
    <tr>
      <td>注意力头</td>
      <td>16</td>
      <td>head_dim=128</td>
    </tr>
    <tr>
      <td>Sampling 步数</td>
      <td>25–50</td>
      <td>Flow Matching 下</td>
    </tr>
  </tbody>
</table>

<hr />

<h2 id="9-t2v-vs-i2v到底改了哪里">9. T2V vs I2V：到底改了哪里</h2>

<p>最后来一张对比表，帮你一眼看清两种模型的差别：</p>

<table>
  <thead>
    <tr>
      <th>方面</th>
      <th>T2V</th>
      <th>I2V</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>输入条件</td>
      <td>只有文本</td>
      <td>文本 + 参考图</td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">patch_embedding</code> in_channels</td>
      <td>16</td>
      <td><strong>36</strong></td>
    </tr>
    <tr>
      <td>Cross-Attention 类型</td>
      <td>单流（只有文本 K/V）</td>
      <td><strong>双流</strong>（image K_img/V_img + text K/V）</td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">img_emb</code> (CLIP → dim MLP)</td>
      <td>❌ 无</td>
      <td>✅ 有</td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">y</code>（mask + image latent）</td>
      <td>❌ 无</td>
      <td>✅ 有，通道拼接到 <code class="language-plaintext highlighter-rouge">x</code></td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">clip_fea</code></td>
      <td>❌ 无</td>
      <td>✅ 前置到 context</td>
    </tr>
    <tr>
      <td>采样过程</td>
      <td>一样（flow matching）</td>
      <td>一样</td>
    </tr>
  </tbody>
</table>

<p>所以 T2V → I2V 的改造量其实并不大：<strong>多了两条图像通路（CLIP 语义 + VAE 像素），外加一组额外的 cross-attention K/V 权重</strong>，其它骨架完全一致。这也是为什么很多团队能从 T2V checkpoint 微调出 I2V 版本。</p>

<hr />

<h2 id="10-常见疑问答疑">10. 常见疑问答疑</h2>

<p><strong>Q1：为什么不只用 CLIP、不要 VAE latent？</strong><br />
只用 CLIP 的话，模型知道”这是一只猫”，但不知道”这只猫在图里具体长什么样、坐在什么位置、毛色分布怎样”。CLIP 太高层。VAE latent 保留了像素级结构，所以首帧能做到”几乎像素级一致”。</p>

<p><strong>Q2：为什么不只用 VAE latent、不要 CLIP？</strong><br />
VAE latent 是”为了重建像素而设计的压缩特征”，它缺乏跨模态语义。CLIP 的语义向量能让模型在后续帧里理解”这张图在讲什么”，从而和 prompt 对齐得更好。两者是语义和像素的两极，缺一不可。</p>

<p><strong>Q3：mask 通道为什么要 4 维，不能是 1 维？</strong><br />
因为 VAE 的时间 stride = 4，一个 latent 帧对应 4 个像素帧。4 通道的 mask 让每个 latent 帧能独立标记”这 4 帧里各自是不是已知”。这样在滚动生成或多帧条件 I2V 里能无缝扩展。</p>

<p><strong>Q4：为什么 cross-attn 不做 RoPE？</strong><br />
RoPE 是为 query/key 在同一个坐标系下的相对距离准备的。cross-attn 的 key 来自外部序列（文本/图像 token），没有和视频 token 共享的”时空坐标”，用 RoPE 反而有害。</p>

<p><strong>Q5：CFG 在 I2V 里到底丢的是什么？</strong><br />
Wan2.1 做 CFG 时通常<strong>只丢文本</strong>（把 <code class="language-plaintext highlighter-rouge">t5_ctx</code> 置空），保留 CLIP 和 VAE latent。因为 I2V 的核心约束是参考图，不能丢；被用来”放大信号”的是文本 prompt。有些实现也会同时丢 CLIP，做”image guidance”。</p>

<p><strong>Q6：能不能做多图 / 多首尾帧条件？</strong><br />
可以。<code class="language-plaintext highlighter-rouge">y</code> 的结构天然支持——只需要把对应帧位置的 mask 设为 1、在 VAE 输入里把那些帧填真实图像即可。这就是社区里各种”首尾帧控制”、”关键帧插值”玩法的实现基础。</p>

<hr />

<h2 id="11-总结">11. 总结</h2>

<p>回到开头那张大图，现在你应该能一眼看懂每一块发生了什么：</p>

<ul>
  <li><strong>VAE</strong> 负责压缩像素和还原像素；</li>
  <li><strong>T5</strong> 负责理解文字；</li>
  <li><strong>CLIP</strong> 负责理解图像的”长相和风格”；</li>
  <li><strong>DiT</strong> 在一个压缩的 latent 空间里，一步一步把噪声拉回视频，拉的方向由前三个模块的条件决定；</li>
  <li><strong>I2V</strong> 的所有”魔法”就是把参考图的信息<strong>同时</strong>从两条通路（像素 / 语义）塞给 DiT，再用 cross-attention 双流、AdaLN 门控把它们融合进每个视频 token。</li>
</ul>

<p>一旦把这张图想清楚，你去读 Wan2.1 源码、甚至去扩展它（做首尾帧、多图参考、风格迁移），都会容易很多。</p>

<hr />

<h2 id="sources">Sources</h2>

<ul>
  <li><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1 官方仓库</a></li>
  <li><a href="https://arxiv.org/abs/2503.20314">Wan 技术报告 arXiv:2503.20314</a></li>
  <li><a href="https://huggingface.co/spaces/2chch/Wan2.1/blob/2a07a689c8837aac720a73915b783ea23b371927/wan/modules/model.py">Wan2.1 model.py 源码（HF 镜像）</a></li>
  <li><a href="https://github.com/Wan-Video/Wan2.1/blob/main/wan/image2video.py">Wan2.1 image2video.py</a></li>
  <li><a href="https://arxiv.org/abs/2212.09748">DiT: Scalable Diffusion Models with Transformers (Peebles &amp; Xie, 2022)</a></li>
  <li><a href="https://arxiv.org/abs/2210.02747">Flow Matching for Generative Modeling (Lipman et al., 2023)</a></li>
  <li><a href="https://arxiv.org/abs/2209.03003">Rectified Flow (Liu et al., 2022)</a></li>
  <li><a href="https://arxiv.org/abs/2103.00020">CLIP (Radford et al., 2021)</a></li>
  <li><a href="https://arxiv.org/abs/2104.09864">RoFormer: RoPE (Su et al., 2021)</a></li>
</ul>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="diffusion" /><category term="video generation" /><category term="image-to-video" /><category term="DiT" /><category term="multimodal" /><summary type="html"><![CDATA[最近视频生成模型卷得很快，Wan2.1 是阿里 Wan 团队开源的那一套。它最常用的场景之一就是 I2V（Image-to-Video）：给一张参考图加一句文字 prompt，模型给你生成一段几秒的视频，首帧基本还是那张图，后续的镜头就按你写的文字去演。]]></summary></entry><entry><title type="html">大模型面试手撕题全攻略：Attention、Transformer、归一化与损失函数</title><link href="https://liyongzhi.xyz/posts/2026/04/llm-interview-implementations/" rel="alternate" type="text/html" title="大模型面试手撕题全攻略：Attention、Transformer、归一化与损失函数" /><published>2026-04-22T00:00:00+08:00</published><updated>2026-04-22T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-llm-interview-implementations</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/llm-interview-implementations/"><![CDATA[<blockquote>
  <p>大模型算法岗面试中，手撕代码是几乎绕不过去的一环。面试官会盯着你从零实现 Attention、MHA、GQA、LayerNorm、RMSNorm、SafeSoftmax、Cross-Entropy 等模块，既考察你对原理的理解，也考察你是否能在紧张的环境下把数值稳定性、维度对齐、broadcasting 这些细节处理干净。</p>

  <p>这篇文章把这些高频手撕题系统梳理一遍：每一节都给出<strong>核心原理 → 数学公式 → 从零手写的 PyTorch 实现 → 面试容易追问的点</strong>，读完之后这一类题你应该都能在白板上 10 分钟内写出来。</p>
</blockquote>

<hr />

<h2 id="1-self-attention所有-transformer-的起点">1. Self-Attention：所有 Transformer 的起点</h2>

<h3 id="11-核心思想">1.1 核心思想</h3>

<p>Self-Attention 要回答的问题非常简单：</p>

<blockquote>
  <p>给定一个序列里的每个 token，它应该从其它 token 里”抄”多少信息过来？</p>
</blockquote>

<p>它的三件套是 <code class="language-plaintext highlighter-rouge">Query</code>、<code class="language-plaintext highlighter-rouge">Key</code>、<code class="language-plaintext highlighter-rouge">Value</code>：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">Query</code>：当前 token 想”找什么”</li>
  <li><code class="language-plaintext highlighter-rouge">Key</code>：其它 token”能提供什么”（用来被检索）</li>
  <li><code class="language-plaintext highlighter-rouge">Value</code>：其它 token”真正要传递的内容”</li>
</ul>

<p>计算流程就一句话：<strong>Query 和 Key 做点积得到相似度，softmax 归一化后再加权求和 Value</strong>。</p>

<h3 id="12-数学公式">1.2 数学公式</h3>

<p>标准的 Scaled Dot-Product Attention：</p>

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]

<p>其中：</p>

<ul>
  <li>$Q \in \mathbb{R}^{n \times d_k}$，$K \in \mathbb{R}^{m \times d_k}$，$V \in \mathbb{R}^{m \times d_v}$</li>
  <li>$n$ 是 query 序列长度，$m$ 是 key/value 序列长度</li>
  <li>$d_k$ 是每个头的维度</li>
</ul>

<p><strong>为什么要除以 $\sqrt{d_k}$？</strong></p>

<p>当 $d_k$ 很大时，$QK^\top$ 的方差会随着 $d_k$ 线性增长。点积数值过大会让 softmax 落到极端区域——梯度趋近 0，模型训不动。除以 $\sqrt{d_k}$ 可以把方差拉回 $O(1)$ 的量级。</p>

<p>简单推导：假设 $q$ 和 $k$ 每个分量均值 0、方差 1 且独立，那么</p>

\[\text{Var}(q \cdot k) = \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) = d_k\]

<p>所以除以 $\sqrt{d_k}$ 后方差变回 1。</p>

<h3 id="13-从零手撕实现">1.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">import</span> <span class="nn">math</span>

<span class="k">def</span> <span class="nf">scaled_dot_product_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="s">"""
    Q: [B, n, d_k]
    K: [B, m, d_k]
    V: [B, m, d_v]
    mask: [B, n, m]  True 的位置会被屏蔽
    """</span>
    <span class="n">d_k</span> <span class="o">=</span> <span class="n">Q</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="c1"># [B, n, m]
</span>    <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
        <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>

    <span class="n">attn</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>       <span class="c1"># [B, n, m]
</span>    <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">V</span><span class="p">)</span>            <span class="c1"># [B, n, d_v]
</span>    <span class="k">return</span> <span class="n">out</span><span class="p">,</span> <span class="n">attn</span>
</code></pre></div></div>

<h3 id="14-面试常见追问">1.4 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：为什么用点积而不是加性注意力（Bahdanau）？</strong>
点积可以用矩阵乘法高效实现，GPU 友好；加性注意力多一层线性变换，不利于大规模并行。</p>
  </li>
  <li>
    <p><strong>Q：mask 怎么做？</strong>
Decoder 的 causal mask 是一个上三角为 True 的矩阵；Padding mask 是把 padding 位置置为 True。二者常合并使用。</p>
  </li>
  <li>
    <p><strong>Q：为什么要用 softmax 而不是别的归一化？</strong>
softmax 保证权重非负且和为 1，符合”加权平均”的语义；同时它是可导的。</p>
  </li>
</ul>

<hr />

<h2 id="2-multi-head-attention-mha">2. Multi-Head Attention (MHA)</h2>

<h3 id="21-为什么要多头">2.1 为什么要多头？</h3>

<p>单头注意力只能学到一种”相似度”模式。多头允许模型在<strong>不同的子空间里关注不同类型的关系</strong>——比如一个头学句法，一个头学语义，一个头学远距离依赖。</p>

<h3 id="22-数学公式">2.2 数学公式</h3>

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O\]

<p>其中每个头独立计算：</p>

\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

<p>维度上：</p>

<ul>
  <li>输入维度 $d_{\text{model}}$，头数 $h$，每头维度 $d_k = d_{\text{model}} / h$</li>
  <li>$W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$</li>
  <li>$W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$</li>
</ul>

<p>注意：<strong>参数总量不变</strong>——切成 $h$ 个头之后每个头更”瘦”，但拼起来总维度还是 $d_{\text{model}}$。</p>

<h3 id="23-从零手撕实现">2.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">d_model</span> <span class="o">%</span> <span class="n">num_heads</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s">"d_model 必须能被 num_heads 整除"</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">num_heads</span>

        <span class="c1"># 一次性投影，效率更高
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="s">"""
        q: [B, n, d_model]
        k, v: [B, m, d_model]
        mask: [B, 1, n, m] 或 [B, n, m]
        """</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">m</span> <span class="o">=</span> <span class="n">k</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="c1"># 1. 线性投影
</span>        <span class="n">Q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>  <span class="c1"># [B, n, d_model]
</span>        <span class="n">K</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
        <span class="n">V</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>

        <span class="c1"># 2. 切分成多头: [B, n, d_model] -&gt; [B, h, n, d_k]
</span>        <span class="n">Q</span> <span class="o">=</span> <span class="n">Q</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">K</span> <span class="o">=</span> <span class="n">K</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># [B, h, m, d_k]
</span>        <span class="n">V</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># [B, h, m, d_k]
</span>
        <span class="c1"># 3. Scaled Dot-Product Attention
</span>        <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">)</span>
        <span class="c1"># scores: [B, h, n, m]
</span>        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">mask</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
                <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>   <span class="c1"># 广播到 head 维
</span>            <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>

        <span class="n">attn</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># [B,h,n,m]
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">V</span><span class="p">)</span>        <span class="c1"># attn [B, h, n, m] V: [B, h, m, d_k] -&gt; [B, h, n, d_k]
</span>
        <span class="c1"># 4. 拼回来: [B, h, n, d_k] -&gt; [B, n, d_model]
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_model</span><span class="p">)</span>

        <span class="c1"># 5. 输出投影
</span>        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</code></pre></div></div>

<h3 id="24-面试常见追问">2.4 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：头数 h 越多越好吗？</strong>
不一定。头数过多会让每个头的 $d_k$ 过小，表达能力反而下降；同时 KV Cache 也会膨胀。实际工程通常 32~64 个头。</p>
  </li>
  <li>
    <p><strong>Q：MHA 的时间复杂度？</strong>
$O(n^2 \cdot d_{\text{model}})$，序列长度是瓶颈。这也是 Flash Attention、线性 Attention 等工作的优化对象。</p>
  </li>
</ul>

<hr />

<h2 id="3-grouped-query-attention-gqa">3. Grouped-Query Attention (GQA)</h2>

<h3 id="31-为什么要-gqa">3.1 为什么要 GQA？</h3>

<p>推理阶段的显存瓶颈主要是 <strong>KV Cache</strong>——每一步解码都要存下所有历史 token 的 K 和 V。</p>

<ul>
  <li>MHA：每个 Q 头都有独立的 K、V 头，KV Cache = $2 \cdot n \cdot h \cdot d_k$</li>
  <li>MQA（Multi-Query Attention）：所有 Q 头共享一组 K、V，KV Cache 缩小 $h$ 倍，但质量下降</li>
  <li><strong>GQA</strong>：折中方案——Q 头分成 $g$ 组，每组共享一组 K、V</li>
</ul>

<p>当 $g = h$ 就是 MHA，当 $g = 1$ 就是 MQA。LLaMA-2/3、Mixtral 等主流开源模型都用的是 GQA。</p>

<p><br />
<img align="center" width="800" src="https://liyongzhi.xyz/images/posts/gqa-comparison.svg" alt="MHA vs GQA vs MQA comparison" />
<br /></p>

<h3 id="32-数学形式">3.2 数学形式</h3>

<p>设 Q 头数为 $h$，KV 头组数为 $g$，每组包含 $h / g$ 个 Q 头共享同一对 K、V：</p>

\[\text{head}_i = \text{Attention}(Q_i, K_{\lfloor i / (h/g) \rfloor}, V_{\lfloor i / (h/g) \rfloor})\]

<h3 id="33-从零手撕实现">3.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">GroupedQueryAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">num_kv_groups</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">d_model</span> <span class="o">%</span> <span class="n">num_heads</span> <span class="o">==</span> <span class="mi">0</span>
        <span class="k">assert</span> <span class="n">num_heads</span> <span class="o">%</span> <span class="n">num_kv_groups</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s">"头数必须能被组数整除"</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_kv_groups</span> <span class="o">=</span> <span class="n">num_kv_groups</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">group_size</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">//</span> <span class="n">num_kv_groups</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="c1"># K, V 只用组数个头
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_kv_groups</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_kv_groups</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">q</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">m</span> <span class="o">=</span> <span class="n">k</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="c1"># Q: [B, h,  n, d_k]
</span>        <span class="c1"># K,V: [B, g, m, d_k]
</span>        <span class="n">Q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span><span class="p">(</span><span class="n">q</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">K</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span><span class="p">(</span><span class="n">k</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_kv_groups</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">V</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span><span class="p">(</span><span class="n">v</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">m</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_kv_groups</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>

        <span class="c1"># 把 K, V 在头维度上"重复 group_size 次"对齐到 Q
</span>        <span class="c1"># [B, g, m, d_k] -&gt; [B, h, m, d_k]
</span>        <span class="n">K</span> <span class="o">=</span> <span class="n">K</span><span class="p">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">group_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">V</span> <span class="o">=</span> <span class="n">V</span><span class="p">.</span><span class="n">repeat_interleave</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">group_size</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">mask</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
                <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>

        <span class="n">attn</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">V</span><span class="p">)</span>

        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_model</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_o</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</code></pre></div></div>

<p><strong>实现细节</strong>：<code class="language-plaintext highlighter-rouge">repeat_interleave</code> 在工程上其实可以避免——直接用 einsum 或在 attention 计算里广播更省显存。但面试场景下写 <code class="language-plaintext highlighter-rouge">repeat_interleave</code> 更直观易读。</p>

<h3 id="34-面试常见追问">3.4 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：GQA 为什么比 MQA 效果好？</strong>
MQA 所有 Q 头共享一套 K、V，表达能力瓶颈明显；GQA 保留了多组 K、V，能在”显存”和”表达”之间做权衡。</p>
  </li>
  <li>
    <p><strong>Q：KV Cache 具体省多少？</strong>
对 LLaMA-2-70B，MHA KV Cache 每层是 $2 \cdot 64 \cdot d_k$，GQA 是 $2 \cdot 8 \cdot d_k$，直接省 8 倍。</p>
  </li>
</ul>

<hr />

<h2 id="4-transformer-encoder-模块">4. Transformer Encoder 模块</h2>

<h3 id="41-结构图">4.1 结构图</h3>

<p>一个完整的 Transformer Encoder Block 由以下部分组成：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>x ──▶ LayerNorm ──▶ MHA ──▶ + ──▶ LayerNorm ──▶ FFN ──▶ + ──▶ out
  │                          ▲                           ▲
  └──────────────────────────┘                           │
                │                                         │
                └─────────────────────────────────────────┘
</code></pre></div></div>

<p>这是 Pre-Norm 结构（GPT、LLaMA 都用这种）。原版 Transformer 用 Post-Norm，训练稳定性差一些，现在主流都切到了 Pre-Norm。</p>

<h3 id="42-公式">4.2 公式</h3>

\[\begin{aligned}
z &amp;= x + \text{MHA}(\text{LN}(x)) \\
y &amp;= z + \text{FFN}(\text{LN}(z))
\end{aligned}\]

<p>其中 FFN 通常是：</p>

\[\text{FFN}(x) = W_2 \cdot \text{GELU}(W_1 x + b_1) + b_2\]

<p>中间维度一般取 $4 \times d_{\text{model}}$。</p>

<h3 id="43-从零手撕实现">4.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">FeedForward</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">gelu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">))))</span>


<span class="k">class</span> <span class="nc">TransformerEncoderBlock</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">ln1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">ln2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">FeedForward</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="c1"># Pre-Norm + Residual
</span>        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ln1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">attn</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">))</span>

        <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ln2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">h</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<h3 id="44-面试常见追问">4.4 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：Pre-Norm 和 Post-Norm 的区别？</strong>
Post-Norm（原版）：$y = \text{LN}(x + \text{Sublayer}(x))$，残差链路被 LN 截断，深层训练不稳定。
Pre-Norm：$y = x + \text{Sublayer}(\text{LN}(x))$，残差是”干净”的恒等映射，梯度更稳，但可能轻微损失表达能力。</p>
  </li>
  <li>
    <p><strong>Q：FFN 中间维度为什么是 4 倍？</strong>
经验值。它提供了模型的主要参数量（约占 2/3），也是存储世界知识的主要场所。</p>
  </li>
  <li>
    <p><strong>Q：GELU 和 ReLU 的区别？</strong>
GELU 是 $x \cdot \Phi(x)$，平滑版 ReLU，负半轴有小幅激活，在语言模型上效果更好。LLaMA 进一步用 SwiGLU。</p>
  </li>
</ul>

<hr />

<h2 id="5-layernorm-vs-batchnorm">5. LayerNorm vs BatchNorm</h2>

<h3 id="51-区别一句话">5.1 区别一句话</h3>

<ul>
  <li><strong>BatchNorm</strong>：对每个<strong>特征维度</strong>，在 <strong>batch 维度</strong>上算均值和方差</li>
  <li><strong>LayerNorm</strong>：对每个<strong>样本</strong>，在<strong>特征维度</strong>上算均值和方差</li>
</ul>

<h3 id="52-公式">5.2 公式</h3>

<p>设输入 $x \in \mathbb{R}^{B \times L \times D}$（Batch × SeqLen × Dim）。</p>

<p><strong>BatchNorm</strong>（对 NLP 几乎不用）：</p>

\[\mu_d = \frac{1}{B \cdot L} \sum_{b, l} x_{b, l, d}, \quad
\sigma_d^2 = \frac{1}{B \cdot L} \sum_{b, l} (x_{b, l, d} - \mu_d)^2\]

\[\hat{x}_{b, l, d} = \frac{x_{b, l, d} - \mu_d}{\sqrt{\sigma_d^2 + \epsilon}}, \quad
y = \gamma \hat{x} + \beta\]

<p><strong>LayerNorm</strong>：</p>

\[\mu_{b, l} = \frac{1}{D} \sum_d x_{b, l, d}, \quad
\sigma_{b, l}^2 = \frac{1}{D} \sum_d (x_{b, l, d} - \mu_{b, l})^2\]

\[\hat{x}_{b, l, d} = \frac{x_{b, l, d} - \mu_{b, l}}{\sqrt{\sigma_{b, l}^2 + \epsilon}}, \quad
y = \gamma \hat{x} + \beta\]

<p>注意 LayerNorm 的 $\gamma, \beta$ 是 $D$ 维向量，不依赖 batch 和 seq。</p>

<h3 id="53-为什么-nlp-用-layernorm-而不是-batchnorm">5.3 为什么 NLP 用 LayerNorm 而不是 BatchNorm？</h3>

<ol>
  <li><strong>变长序列</strong>：NLP 输入有大量 padding，padding 位置参与 BN 统计会污染结果。</li>
  <li><strong>小 batch</strong>：语言模型 batch 常常不大（长序列更吃显存），BN 在小 batch 上统计量不稳。</li>
  <li><strong>训练/推理一致</strong>：BN 推理时用 running mean/var，语言模型分布漂移敏感；LN 训推完全一致。</li>
  <li><strong>每个 token 独立归一化</strong>更贴合语言模型”逐 token 建模”的直觉。</li>
</ol>

<h3 id="54-从零手撕实现">5.4 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">LayerNorm</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x: [..., dim]
</span>        <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">var</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">var</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">unbiased</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">x_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">eps</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">x_hat</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">beta</span>


<span class="k">class</span> <span class="nc">BatchNorm1d</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_features</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_features</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'running_mean'</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">num_features</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'running_var'</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_features</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentum</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x: [B, D]
</span>        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">training</span><span class="p">:</span>
            <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
            <span class="n">var</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">var</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">unbiased</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
            <span class="c1"># 更新 running 统计量
</span>            <span class="bp">self</span><span class="p">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">running_mean</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">mean</span><span class="p">.</span><span class="n">detach</span><span class="p">()</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="p">.</span><span class="n">momentum</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">running_var</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">momentum</span> <span class="o">*</span> <span class="n">var</span><span class="p">.</span><span class="n">detach</span><span class="p">()</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">running_mean</span>
            <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">running_var</span>

        <span class="n">x_hat</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">torch</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">eps</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span> <span class="o">*</span> <span class="n">x_hat</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">beta</span>
</code></pre></div></div>

<p><strong>注意</strong>：<code class="language-plaintext highlighter-rouge">unbiased=False</code> 表示用 $1/N$ 而不是 $1/(N-1)$，这是神经网络里常用的有偏估计，和 PyTorch 默认实现一致。</p>

<hr />

<h2 id="6-rmsnormllama-时代的归一化标配">6. RMSNorm：LLaMA 时代的归一化标配</h2>

<h3 id="61-动机">6.1 动机</h3>

<p>LayerNorm 做了两件事：<strong>减均值</strong>（中心化） + <strong>除标准差</strong>（缩放）。
但有研究发现，<strong>减均值这一步对性能的贡献非常小</strong>——真正起作用的是”缩放”。</p>

<p>于是 RMSNorm 提出：直接去掉减均值，只做缩放，用 RMS（Root Mean Square）替代标准差：</p>

<ul>
  <li>省掉一次均值计算和减法</li>
  <li>实测 7%~64% 的速度提升</li>
  <li>效果和 LayerNorm 持平甚至更好</li>
</ul>

<p>LLaMA、LLaMA-2/3、Mistral、Qwen 等主流模型全都用 RMSNorm。</p>

<h3 id="62-公式">6.2 公式</h3>

\[\text{RMS}(x) = \sqrt{\frac{1}{D} \sum_{d=1}^{D} x_d^2}\]

\[\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \odot \gamma\]

<p>注意：</p>

<ul>
  <li>没有 $\beta$（平移项）</li>
  <li>分母里是 RMS 而不是 std</li>
  <li>$\gamma$ 是可学习的缩放向量</li>
</ul>

<h3 id="63-和-layernorm-的关系">6.3 和 LayerNorm 的关系</h3>

<p>如果 $x$ 的均值恰好为 0，那么 $\text{RMS}(x) = \text{std}(x)$，RMSNorm 就退化为没有 bias 的 LayerNorm。</p>

<p>换句话说，<strong>RMSNorm = LayerNorm 扔掉均值平移</strong>。</p>

<h3 id="64-从零手撕实现">6.4 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">RMSNorm</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x: [..., dim]
</span>        <span class="c1"># 注意用 float32 算 rms 避免 fp16 下溢
</span>        <span class="n">rms</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nb">float</span><span class="p">().</span><span class="nb">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">).</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">eps</span><span class="p">).</span><span class="n">rsqrt</span><span class="p">()</span>
        <span class="k">return</span> <span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="nb">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">rms</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">gamma</span>
</code></pre></div></div>

<p><strong>工程细节</strong>：</p>
<ul>
  <li><code class="language-plaintext highlighter-rouge">rsqrt()</code> 比 <code class="language-plaintext highlighter-rouge">1 / sqrt()</code> 更快</li>
  <li>混合精度下先 cast 到 fp32 计算，结果再 cast 回原 dtype，可以避免数值不稳定</li>
  <li>eps 通常取 $10^{-6}$ 或 $10^{-5}$</li>
</ul>

<h3 id="65-面试常见追问">6.5 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：RMSNorm 为什么比 LayerNorm 快？</strong>
少了一次”求均值 + 做减法”的操作，访存和计算都减半。</p>
  </li>
  <li>
    <p><strong>Q：去掉 $\beta$ 不会损失表达能力吗？</strong>
理论上会，但实验发现对语言模型几乎无影响。可能原因是残差连接本身已经提供了足够的 bias 能力。</p>
  </li>
</ul>

<hr />

<h2 id="7-safe-softmax数值稳定性的必考点">7. Safe Softmax：数值稳定性的必考点</h2>

<h3 id="71-朴素-softmax-的问题">7.1 朴素 Softmax 的问题</h3>

\[\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\]

<p>当 $x_i$ 很大时（比如 1000），$e^{x_i}$ 直接 <strong>上溢为 inf</strong>；当 $x_i$ 很小时，$e^{x_i}$ 下溢为 0。fp16/bf16 下这个问题尤其严重（fp16 的最大值只有 65504）。</p>

<h3 id="72-safe-softmax-技巧">7.2 Safe Softmax 技巧</h3>

<p>利用 softmax 的<strong>平移不变性</strong>：</p>

\[\text{softmax}(x_i) = \text{softmax}(x_i - c)\]

<p>因为分子分母同时乘 $e^{-c}$ 会约掉。所以我们可以取 $c = \max_j x_j$：</p>

\[\text{softmax}(x_i) = \frac{e^{x_i - \max_j x_j}}{\sum_j e^{x_j - \max_j x_j}}\]

<p>这样：</p>

<ul>
  <li>指数的最大值变成 $e^0 = 1$，永远不会上溢</li>
  <li>分母至少有一项是 1，不会下溢为 0</li>
</ul>

<h3 id="73-从零手撕实现">7.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">safe_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
    <span class="s">"""
    数值稳定的 softmax
    x: 任意 shape 的 tensor
    dim: 在哪个维度做 softmax
    """</span>
    <span class="c1"># 减去最大值防溢出（注意 detach/不 detach 都不影响梯度，因为是平移不变的）
</span>    <span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">values</span>
    <span class="n">x_shifted</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span>

    <span class="n">exp_x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x_shifted</span><span class="p">)</span>
    <span class="n">sum_exp</span> <span class="o">=</span> <span class="n">exp_x</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">exp_x</span> <span class="o">/</span> <span class="n">sum_exp</span>


<span class="k">def</span> <span class="nf">safe_log_softmax</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">):</span>
    <span class="s">"""
    数值稳定的 log_softmax，后面算交叉熵要用到
    log_softmax(x) = x - max - log(sum(exp(x - max)))
    """</span>
    <span class="n">x_max</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">values</span>
    <span class="n">x_shifted</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">x_max</span>
    <span class="c1"># log-sum-exp
</span>    <span class="n">log_sum_exp</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">x_shifted</span><span class="p">).</span><span class="nb">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">x_shifted</span> <span class="o">-</span> <span class="n">log_sum_exp</span>
</code></pre></div></div>

<h3 id="74-面试常见追问">7.4 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：减最大值这个操作会影响梯度吗？</strong>
不会。softmax 对平移不变，数学上完全等价，梯度也完全等价。</p>
  </li>
  <li>
    <p><strong>Q：如果所有输入都是 -inf（比如全被 mask）怎么办？</strong>
会出现 NaN（0/0）。实际实现里对 attention 会保证至少有一个非 mask 的位置，或者在 softmax 之后把整行置零。</p>
  </li>
  <li>
    <p><strong>Q：Flash Attention 里的 online softmax 是什么？</strong>
Flash Attention 在 tiling 过程中不能一次看到所有的 logits，它用 <strong>递推公式</strong> 维护当前已见的最大值和分母，逐块更新。这是 safe softmax 的分块在线版本。</p>
  </li>
</ul>

<hr />

<h2 id="8-cross-entropy-loss分类任务的灵魂">8. Cross-Entropy Loss：分类任务的灵魂</h2>

<h3 id="81-公式">8.1 公式</h3>

<p>对于 $C$ 分类问题，模型输出 logits $z \in \mathbb{R}^C$，真实标签 $y \in {0, 1, \dots, C-1}$。</p>

<p>先 softmax 得到概率：</p>

\[p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}\]

<p>交叉熵损失：</p>

\[\mathcal{L} = -\log p_y = -\log \frac{e^{z_y}}{\sum_j e^{z_j}} = -z_y + \log \sum_j e^{z_j}\]

<p>最右边的形式就是我们常说的 <strong>LogSumExp</strong> 形式，非常适合数值稳定实现。</p>

<p><strong>批量版</strong>：</p>

\[\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log p_{n, y_n}\]

<h3 id="82-为什么要用交叉熵而不是-mse">8.2 为什么要用交叉熵而不是 MSE？</h3>

<ol>
  <li><strong>梯度性质好</strong>：对 logits 求导，$\frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y]$，形式简洁且不会出现”梯度消失”。</li>
  <li><strong>概率语义契合</strong>：交叉熵衡量两个分布的距离，和 softmax 输出的”概率”天然配对。</li>
  <li><strong>MSE 配 softmax 会梯度饱和</strong>：预测很离谱时梯度反而很小，训练慢。</li>
</ol>

<h3 id="83-从零手撕实现">8.3 从零手撕实现</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">cross_entropy_loss</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s">'mean'</span><span class="p">):</span>
    <span class="s">"""
    logits: [N, C]  未经 softmax 的原始输出
    targets: [N]    每个样本的真实标签 index
    """</span>
    <span class="c1"># 用 log_softmax 的形式，避免 log(0)
</span>    <span class="n">log_probs</span> <span class="o">=</span> <span class="n">safe_log_softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [N, C]
</span>
    <span class="c1"># gather 出真实标签对应的 log_prob
</span>    <span class="c1"># log_probs.gather(1, targets.unsqueeze(1)).squeeze(1): [N]
</span>    <span class="n">nll</span> <span class="o">=</span> <span class="o">-</span><span class="n">log_probs</span><span class="p">.</span><span class="n">gather</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">targets</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)).</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s">'mean'</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">nll</span><span class="p">.</span><span class="n">mean</span><span class="p">()</span>
    <span class="k">elif</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s">'sum'</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">nll</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">nll</span>
</code></pre></div></div>

<h3 id="84-带-label-smoothing-的版本">8.4 带 Label Smoothing 的版本</h3>

<p>大模型训练里经常用到 Label Smoothing：不把真实标签当成 one-hot，而是留一点点给别的类别，防止模型”过分自信”。</p>

\[\tilde{y}_i = \begin{cases}
1 - \epsilon + \epsilon / C &amp; i = y \\
\epsilon / C &amp; i \neq y
\end{cases}\]

\[\mathcal{L} = -\sum_{i} \tilde{y}_i \log p_i\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">cross_entropy_with_label_smoothing</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">smoothing</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    <span class="n">N</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">logits</span><span class="p">.</span><span class="n">shape</span>
    <span class="n">log_probs</span> <span class="o">=</span> <span class="n">safe_log_softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># 构造平滑后的标签分布
</span>    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="n">true_dist</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">full_like</span><span class="p">(</span><span class="n">log_probs</span><span class="p">,</span> <span class="n">smoothing</span> <span class="o">/</span> <span class="n">C</span><span class="p">)</span>
        <span class="n">true_dist</span><span class="p">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">targets</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">smoothing</span> <span class="o">+</span> <span class="n">smoothing</span> <span class="o">/</span> <span class="n">C</span><span class="p">)</span>

    <span class="c1"># -(true_dist * log_probs).sum(dim=-1).mean()
</span>    <span class="k">return</span> <span class="o">-</span><span class="p">(</span><span class="n">true_dist</span> <span class="o">*</span> <span class="n">log_probs</span><span class="p">).</span><span class="nb">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">mean</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="85-面试常见追问">8.5 面试常见追问</h3>

<ul>
  <li>
    <p><strong>Q：PyTorch 的 <code class="language-plaintext highlighter-rouge">nn.CrossEntropyLoss</code> 输入是 logits 还是概率？</strong>
logits！它内部融合了 log_softmax + nll_loss，比手写分开两步更稳定、更快。</p>
  </li>
  <li>
    <p><strong>Q：对 logits 求导的结果推一下？</strong>
$\frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y]$，这就是为什么反向传播时只需要”预测概率减去 one-hot”。</p>
  </li>
  <li>
    <p><strong>Q：语言模型训练时怎么处理 padding 的 loss？</strong>
用 <code class="language-plaintext highlighter-rouge">ignore_index</code>（PyTorch 原生支持），或者构造 mask 在 reduction 之前把 padding 位置的 loss 置零。</p>
  </li>
</ul>

<hr />

<h2 id="9-一个完整的白板样板">9. 一个完整的”白板样板”</h2>

<p>最后给一个浓缩版的”应急套路”——如果面试官让你 5 分钟手撕一个精简 Transformer Block，就按下面这个最小实现来：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">import</span> <span class="nn">math</span>


<span class="k">class</span> <span class="nc">RMSNorm</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">eps</span> <span class="o">=</span> <span class="n">eps</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">rms</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nb">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">).</span><span class="n">mean</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">add</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">eps</span><span class="p">).</span><span class="n">rsqrt</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">rms</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">weight</span>


<span class="k">class</span> <span class="nc">MHA</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">h</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span> <span class="o">=</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">//</span> <span class="n">n_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">wq</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">wk</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">wv</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">wo</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wq</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">h</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wk</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">h</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">wv</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">h</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>

        <span class="n">scores</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">d_k</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="p">(</span><span class="n">attn</span> <span class="o">@</span> <span class="n">v</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">D</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">wo</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>


<span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n1</span> <span class="o">=</span> <span class="n">RMSNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">MHA</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n2</span> <span class="o">=</span> <span class="n">RMSNorm</span><span class="p">(</span><span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">),</span> <span class="n">nn</span><span class="p">.</span><span class="n">GELU</span><span class="p">(),</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n1</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">mask</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">ffn</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>

<p>把这套板子背熟，再根据面试官追问填 GQA、KV Cache、RoPE 等变种即可。</p>

<hr />

<h2 id="10-总结">10. 总结</h2>

<p>本文把大模型算法岗最常考的一组手撕题串起来：</p>

<table>
  <thead>
    <tr>
      <th>模块</th>
      <th>一句话总结</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Self-Attention</td>
      <td>Q·Kᵀ / √d_k 再 softmax 加权 V</td>
    </tr>
    <tr>
      <td>MHA</td>
      <td>切成 h 个头并行算 Attention，拼回再投影</td>
    </tr>
    <tr>
      <td>GQA</td>
      <td>Q 独立 h 个头，KV 只有 g 组头，省 KV Cache</td>
    </tr>
    <tr>
      <td>Transformer Encoder</td>
      <td>Pre-Norm + MHA + Pre-Norm + FFN，两段残差</td>
    </tr>
    <tr>
      <td>LayerNorm</td>
      <td>每个样本在特征维上做归一化</td>
    </tr>
    <tr>
      <td>BatchNorm</td>
      <td>每个特征在 batch 维上做归一化（NLP 不用）</td>
    </tr>
    <tr>
      <td>RMSNorm</td>
      <td>LayerNorm 去掉均值平移，只保留 RMS 缩放</td>
    </tr>
    <tr>
      <td>Safe Softmax</td>
      <td>减去最大值再做 exp，避免上溢</td>
    </tr>
    <tr>
      <td>Cross-Entropy</td>
      <td>$-\log p_y$，实战用 log_softmax + nll_loss</td>
    </tr>
  </tbody>
</table>

<p>建议的学习路径：先把每一节的公式自己手推一遍，再白板默写实现，然后用 <code class="language-plaintext highlighter-rouge">torch.allclose</code> 和 <code class="language-plaintext highlighter-rouge">nn</code> 自带模块对拍一下数值，最后在纸上做时空复杂度分析。真正把这一整套走完之后，这类题目你都能在面试里淡定应付了。</p>

<hr />

<h2 id="参考资料">参考资料</h2>

<ol>
  <li>Vaswani et al., <em>Attention is All You Need</em>, 2017</li>
  <li>Ba et al., <em>Layer Normalization</em>, 2016</li>
  <li>Zhang &amp; Sennrich, <em>Root Mean Square Layer Normalization</em>, 2019</li>
  <li>Ainslie et al., <em>GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints</em>, 2023</li>
  <li>Touvron et al., <em>LLaMA: Open and Efficient Foundation Language Models</em>, 2023</li>
</ol>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="llm" /><category term="transformer" /><category term="attention" /><category term="interview" /><category term="machine learning" /><summary type="html"><![CDATA[大模型算法岗面试中，手撕代码是几乎绕不过去的一环。面试官会盯着你从零实现 Attention、MHA、GQA、LayerNorm、RMSNorm、SafeSoftmax、Cross-Entropy 等模块，既考察你对原理的理解，也考察你是否能在紧张的环境下把数值稳定性、维度对齐、broadcasting 这些细节处理干净。 这篇文章把这些高频手撕题系统梳理一遍：每一节都给出核心原理 → 数学公式 → 从零手写的 PyTorch 实现 → 面试容易追问的点，读完之后这一类题你应该都能在白板上 10 分钟内写出来。]]></summary></entry><entry><title type="html">从 Classifier Guidance 到 Classifier-Free Guidance：一文讲清 Diffusion 里的 CFG</title><link href="https://liyongzhi.xyz/posts/2026/04/diffusion-cfg/" rel="alternate" type="text/html" title="从 Classifier Guidance 到 Classifier-Free Guidance：一文讲清 Diffusion 里的 CFG" /><published>2026-04-20T00:00:00+08:00</published><updated>2026-04-20T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-diffusion-cfg</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/diffusion-cfg/"><![CDATA[<p>Diffusion 模型发展到今天，<code class="language-plaintext highlighter-rouge">CFG</code> 几乎已经成了文本生成图像系统里的“默认组件”。<br />
但很多人第一次看到它时都会困惑：</p>

<ul>
  <li>为什么一个模型要做“有条件前向”一次、“无条件前向”一次？</li>
  <li>为什么不能直接训练一个很强的条件模型？</li>
  <li><code class="language-plaintext highlighter-rouge">Classifier Guidance</code> 和 <code class="language-plaintext highlighter-rouge">Classifier-Free Guidance</code> 到底是什么关系？</li>
  <li>为什么后来的蒸馏模型又常常说“已经不需要 CFG 了”？</li>
</ul>

<p>这篇文章想做的事情，就是把这条知识脉络从头理顺：<br />
从无条件 Diffusion 的起点出发，讲到 <code class="language-plaintext highlighter-rouge">Classifier Guidance</code>，再讲到今天真正主流的 <code class="language-plaintext highlighter-rouge">Classifier-Free Guidance (CFG)</code>，最后补上 few-step / distillation 路线和近两年的一些延伸工作。</p>

<p><br />
<img align="center" width="1000" src="https://liyongzhi.xyz/images/posts/diffusion-cfg-evolution.svg" alt="Diffusion guidance evolution timeline" />
<br /></p>

<p>上图可以先当作全文导航来读：<br />
最左边是无条件 diffusion 的起点，中间是两代 guidance，最右边是把 guidance 效果蒸进 few-step 模型的后续路线。</p>

<hr />

<h2 id="1-起点无条件-diffusion-学的只是-px">1. 起点：无条件 Diffusion 学的只是 <code class="language-plaintext highlighter-rouge">p(x)</code></h2>

<p>先从最原始的 DDPM 视角看问题。<br />
一个<strong>无条件</strong> diffusion 模型学的是数据分布 <code class="language-plaintext highlighter-rouge">p(x)</code>，也就是：</p>

<blockquote>
  <p>什么样的样本看起来像“真实世界中的自然图像”。</p>
</blockquote>

<p>它并不知道你想生成什么。<br />
所以如果你只给它高斯噪声，它学会的是“把噪声慢慢拉回自然图像流形上”，而不是“把噪声拉成一只猫”或者“拉成一辆红色跑车”。</p>

<p>在 DDPM 的常见参数化里，前向加噪可以写成：</p>

\[x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon,\quad \epsilon \sim \mathcal{N}(0, I)\]

<p>训练时模型通常预测噪声：</p>

\[\epsilon_\theta(x_t, t)\]

<p>Ho 等人在 DDPM 里说明了这种噪声预测与 denoising score matching 存在紧密联系，因此我们常把 diffusion 模型理解为在不同噪声水平下学习一个 score function 的近似器。[1]</p>

<p>更直观一点说：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">x_t</code> 是某个时刻的噪声图</li>
  <li>模型要回答的问题是：“这张图里哪一部分是噪声？应该往哪个方向去噪？”</li>
  <li>如果模型只学了 <code class="language-plaintext highlighter-rouge">p(x)</code>，它最多只能说“往更像自然图像的方向去”</li>
</ul>

<p>但条件生成想做的是：</p>

\[p(x|c)\]

<p>也就是：</p>

<blockquote>
  <p>在满足条件 <code class="language-plaintext highlighter-rouge">c</code> 的前提下，什么样的图像是合理的？</p>
</blockquote>

<p>这里的 <code class="language-plaintext highlighter-rouge">c</code> 可以是类别标签、文本、另一张图像、语音 embedding，甚至更复杂的多模态条件。</p>

<p>问题于是变成了：</p>

<blockquote>
  <p>怎么把条件信号注入去噪过程？</p>
</blockquote>

<p>这条路上，Diffusion 社区先后给出了两代代表性答案。</p>

<hr />

<h2 id="2-第一代方案classifier-guidance">2. 第一代方案：Classifier Guidance</h2>

<p><code class="language-plaintext highlighter-rouge">Classifier Guidance</code> 的代表工作是 Dhariwal 和 Nichol 在 2021 年的 <em>Diffusion Models Beat GANs on Image Synthesis</em>。[2]</p>

<p>它的核心想法非常优雅：<br />
先保留一个无条件 diffusion 模型来提供“自然图像先验”，然后再额外训练一个分类器，告诉模型“怎样更像某个类别”。</p>

<h3 id="21-贝叶斯公式是整个故事的起点">2.1 贝叶斯公式是整个故事的起点</h3>

<p>对于类别条件 <code class="language-plaintext highlighter-rouge">c</code>，有：</p>

\[p(x_t|c) \propto p(c|x_t)\,p(x_t)\]

<p>两边取对数并对 <code class="language-plaintext highlighter-rouge">x_t</code> 求梯度，可得：</p>

\[\nabla_{x_t}\log p(x_t|c)=\nabla_{x_t}\log p(x_t)+\nabla_{x_t}\log p(c|x_t)\]

<p>这行式子的含义特别重要：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">\nabla_{x_t}\log p(x_t)</code>：让样本更像“自然图像”</li>
  <li><code class="language-plaintext highlighter-rouge">\nabla_{x_t}\log p(c|x_t)</code>：让样本更像“类别 c”</li>
</ul>

<p>于是条件生成的 score 可以拆成：</p>

<blockquote>
  <p>无条件生成能力 + 分类器给的条件方向</p>
</blockquote>

<h3 id="22-它在采样时是怎么工作的">2.2 它在采样时是怎么工作的</h3>

<p>直觉化地写，Classifier Guidance 的采样过程可以理解为：</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Step 1: diffusion 模型在当前 x_t 上做一次前向
        得到无条件噪声预测或 score

Step 2: 分类器在同一个 x_t 上判断“这有多像类别 c”
        然后对 x_t 求梯度

Step 3: 用这个梯度去修正 diffusion 模型的去噪方向

Step 4: 按修正后的方向走一步，得到 x_{t-1}
</code></pre></div></div>

<p>在 Dhariwal &amp; Nichol 的论文里，这个修正以“对采样均值加上分类器梯度项”的形式写进算法；如果改写到更常见的 <code class="language-plaintext highlighter-rouge">\epsilon</code>-prediction 记号中，会看到大家熟悉的“沿着分类器梯度方向做引导”。<br />
需要注意的是：<strong>不同采样器、不同参数化下常数项会略有差异</strong>，所以教程里出现的公式看起来可能不完全一样，但核心思想是一致的。[2]</p>

<h3 id="23-这个方法为什么当时很重要">2.3 这个方法为什么当时很重要</h3>

<p>因为它第一次清楚地证明了：</p>

<blockquote>
  <p>diffusion 模型也可以像 GAN 一样，通过引导在“样本保真度”和“分布覆盖度”之间做可控权衡。</p>
</blockquote>

<p>Dhariwal &amp; Nichol 在 ImageNet 上展示了很强的结果：通过 classifier guidance，他们把 conditional diffusion 的质量显著拉高，并把 diffusion 真正推到了“能和当时顶级 GAN 正面竞争”的阶段。[2]</p>

<h3 id="24-但它有一个很重的代价分类器必须在噪声图上工作">2.4 但它有一个很重的代价：分类器必须在噪声图上工作</h3>

<p>这是 Classifier Guidance 最大的工程痛点。</p>

<p>普通分类器只见过干净图像 <code class="language-plaintext highlighter-rouge">x_0</code>，但 diffusion 采样时给它的是各种噪声水平下的 <code class="language-plaintext highlighter-rouge">x_t</code>。<br />
因此你不能直接拿一个普通 ImageNet 分类器来引导 diffusion，而必须训练一个<strong>噪声感知分类器</strong>：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train_noisy_classifier</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">t</span> <span class="o">=</span> <span class="n">sample_timestep</span><span class="p">()</span>
    <span class="n">noise</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span>
    <span class="n">x_t</span> <span class="o">=</span> <span class="n">add_noise</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">noise</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
    <span class="n">logits</span> <span class="o">=</span> <span class="n">classifier</span><span class="p">(</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">loss</span>
</code></pre></div></div>

<p>也就是说，整个系统变成了：</p>

<ol>
  <li>一个 diffusion 模型</li>
  <li>一个额外分类器</li>
  <li>分类器还要在所有噪声水平上都能稳定工作</li>
</ol>

<p>而且推理时为了拿到 <code class="language-plaintext highlighter-rouge">\nabla_{x_t}\log p(c|x_t)</code>，还需要对分类器做反向传播，这会带来额外显存和速度开销。</p>

<hr />

<h2 id="3-第二代方案classifier-free-guidance">3. 第二代方案：Classifier-Free Guidance</h2>

<p>真正把 diffusion 条件生成推成主流工业范式的，是 Ho 和 Salimans 在 2021 年提出的 <em>Classifier-Free Diffusion Guidance</em>。[3]</p>

<p>这篇工作的标题其实就已经说明了本质：</p>

<blockquote>
  <p>classifier guidance without a classifier</p>
</blockquote>

<p>也就是：</p>

<blockquote>
  <p>我还想要 guidance 的效果，但我不想再训练一个分类器。</p>
</blockquote>

<h3 id="31-关键替换把分类器梯度改写成两个-score-的差">3.1 关键替换：把“分类器梯度”改写成两个 score 的差</h3>

<p>从贝叶斯公式出发：</p>

\[p(c|x_t)=\frac{p(x_t|c)p(c)}{p(x_t)}\]

<p>对数化后对 <code class="language-plaintext highlighter-rouge">x_t</code> 求梯度：</p>

\[\nabla_{x_t}\log p(c|x_t)
=
\nabla_{x_t}\log p(x_t|c)
-\nabla_{x_t}\log p(x_t)\]

<p>这里 <code class="language-plaintext highlighter-rouge">p(c)</code> 对 <code class="language-plaintext highlighter-rouge">x_t</code> 来说是常数，所以梯度为 0。</p>

<p>这一步非常关键，因为它说明：</p>

<blockquote>
  <p>分类器梯度，本质上可以看成“条件 score”和“无条件 score”的差值。</p>
</blockquote>

<p>而 diffusion 模型本来就在学 score 的近似。<br />
所以如果同一个模型既能输出条件版本，又能输出无条件版本，那么分类器的作用就可以被“模型自己的两次前向”替代。</p>

<h3 id="32-从-score-形式到大家熟悉的-cfg-公式">3.2 从 score 形式到大家熟悉的 CFG 公式</h3>

<p>在常见 VP diffusion / <code class="language-plaintext highlighter-rouge">\epsilon</code>-prediction 的记号下，可以把这件事写成：</p>

\[\hat{\epsilon}
=
\epsilon_\theta(x_t,t,\varnothing)
+ w\cdot\left[\epsilon_\theta(x_t,t,c)-\epsilon_\theta(x_t,t,\varnothing)\right]\]

<p>其中：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">\epsilon_\theta(x_t,t,c)</code>：有条件噪声预测</li>
  <li><code class="language-plaintext highlighter-rouge">\epsilon_\theta(x_t,t,\varnothing)</code>：无条件噪声预测</li>
  <li><code class="language-plaintext highlighter-rouge">w</code>：guidance scale</li>
</ul>

<p>有时你也会看到等价写法：</p>

\[\hat{\epsilon}
=
(1+w)\epsilon_\theta(x_t,t,c)-w\epsilon_\theta(x_t,t,\varnothing)\]

<p>两者完全等价，只是展开方式不同。</p>

<h3 id="33-为什么-score-能改写成噪声预测">3.3 为什么 score 能改写成噪声预测</h3>

<p>上面这一步里，很多人最容易卡住的地方是：</p>

<blockquote>
  <p>前面推的是 score，为什么后面突然变成了噪声预测 <code class="language-plaintext highlighter-rouge">\epsilon_\theta</code>？</p>
</blockquote>

<p>关键在于 VP diffusion 的前向分布：</p>

\[q(x_t|x_0)=\mathcal{N}\left(\sqrt{\bar{\alpha}_t}x_0,\,(1-\bar{\alpha}_t)I\right)\]

<p>对 <code class="language-plaintext highlighter-rouge">x_t</code> 求对数梯度：</p>

\[\nabla_{x_t}\log q(x_t|x_0)
=
-\frac{x_t-\sqrt{\bar{\alpha}_t}x_0}{1-\bar{\alpha}_t}
=
-\frac{\epsilon}{\sqrt{1-\bar{\alpha}_t}}\]

<p>因为：</p>

\[x_t-\sqrt{\bar{\alpha}_t}x_0=\sqrt{1-\bar{\alpha}_t}\,\epsilon\]

<p>所以在这个参数化下，score 和噪声只差一个与时间步有关的缩放因子。<br />
这也是为什么 diffusion 文献里常把：</p>

\[\nabla_{x_t}\log p_t(x_t|c)\]

<p>近似写成：</p>

\[-\frac{1}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t,c)\]

<p>严格地说，这里对应的是扰动后边缘分布的 score 近似；不同参数化和不同 sampler 下写法会有差别，但对理解 CFG 来说，这个关系已经足够用了。</p>

<h3 id="34-直觉上它到底在干什么">3.4 直觉上它到底在干什么</h3>

<p>把上面的公式拆开看：</p>

\[\epsilon_\theta(x_t,t,c)-\epsilon_\theta(x_t,t,\varnothing)\]

<p>这一项可以理解为：</p>

<blockquote>
  <p>条件 <code class="language-plaintext highlighter-rouge">c</code> 相对于“什么都不说”所带来的额外方向</p>
</blockquote>

<p>所以 CFG 其实是在做一件非常朴素的事：</p>

<ol>
  <li>先得到“自然去噪方向”</li>
  <li>再提取出“条件带来的额外偏移”</li>
  <li>把这部分偏移乘上一个更大的系数</li>
</ol>

<p>这就是为什么很多人把 CFG 理解成“方向放大器”。<br />
它不是凭空发明一个新方向，而是在<strong>有条件</strong>和<strong>无条件</strong>之间做对比，把“条件真正贡献的那一小段方向”放大。</p>

<hr />

<h2 id="4-训练阶段为什么一个模型能同时学会有条件和无条件">4. 训练阶段：为什么一个模型能同时学会有条件和无条件</h2>

<p>CFG 成立的前提是：</p>

<blockquote>
  <p>同一个模型既会做 <code class="language-plaintext highlighter-rouge">\epsilon(x_t,t,c)</code>，也会做 <code class="language-plaintext highlighter-rouge">\epsilon(x_t,t,\varnothing)</code>。</p>
</blockquote>

<p>Ho 和 Salimans 的做法很简单：<strong>训练时随机把条件丢掉</strong>。[3]</p>

<p>伪代码如下：</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">c</span><span class="p">):</span>
    <span class="n">t</span> <span class="o">=</span> <span class="n">sample_timestep</span><span class="p">()</span>
    <span class="n">noise</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span>
    <span class="n">x_t</span> <span class="o">=</span> <span class="n">add_noise</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">noise</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>

    <span class="k">if</span> <span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">p_drop</span><span class="p">:</span>
        <span class="n">c</span> <span class="o">=</span> <span class="n">null_condition</span>

    <span class="n">eps_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">mse</span><span class="p">(</span><span class="n">eps_pred</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">loss</span>
</code></pre></div></div>

<p>其中 <code class="language-plaintext highlighter-rouge">p_drop</code> 通常设成 10% 到 20% 左右。</p>

<p>这意味着训练过程中模型会见到两类样本：</p>

<ul>
  <li>大多数时候：正常条件训练</li>
  <li>少数时候：条件被替换为空条件</li>
</ul>

<p>于是模型自然学会了两种行为模式：</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>输入真实条件 c   -&gt; 输出条件去噪结果
输入空条件 ∅     -&gt; 输出无条件去噪结果
</code></pre></div></div>

<p>所以推理时只要把同一个 <code class="language-plaintext highlighter-rouge">x_t</code> 喂给模型两次：</p>

<ul>
  <li>一次给真实条件</li>
  <li>一次给空条件</li>
</ul>

<p>就能得到 CFG 所需的两个分支。</p>

<hr />

<h2 id="5-classifier-guidance-和-cfg-到底差在哪里">5. Classifier Guidance 和 CFG 到底差在哪里</h2>

<p>把两代方法摆在一起，对比会非常清楚。</p>

<table>
  <thead>
    <tr>
      <th>维度</th>
      <th>Classifier Guidance</th>
      <th>Classifier-Free Guidance</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>需要几个模型</td>
      <td>2 个</td>
      <td>1 个</td>
    </tr>
    <tr>
      <td>是否需要额外分类器</td>
      <td>需要</td>
      <td>不需要</td>
    </tr>
    <tr>
      <td>分类器是否要适配噪声图</td>
      <td>需要</td>
      <td>不需要</td>
    </tr>
    <tr>
      <td>训练方式</td>
      <td>diffusion 无条件训练 + 分类器单独训练</td>
      <td>条件训练 + condition dropout</td>
    </tr>
    <tr>
      <td>推理开销</td>
      <td>diffusion 前向 + 分类器反向</td>
      <td>两次 diffusion 前向</td>
    </tr>
    <tr>
      <td>条件类型</td>
      <td>更适合离散类别</td>
      <td>任意 embedding 条件</td>
    </tr>
    <tr>
      <td>工程复杂度</td>
      <td>高</td>
      <td>低</td>
    </tr>
    <tr>
      <td>当前使用情况</td>
      <td>历史重要，但已非主流</td>
      <td>现代主流</td>
    </tr>
  </tbody>
</table>

<p>如果只看数学推导，这两者像是“同一家族的两个版本”。<br />
但如果从工程实现看，差距非常大。</p>

<p><br />
<img align="center" width="1000" src="https://liyongzhi.xyz/images/posts/diffusion-cfg-compare.svg" alt="Classifier Guidance versus Classifier-Free Guidance" />
<br /></p>

<p>如果你只想先抓住最核心的工程差异，那么看这张图就够了：<br />
<code class="language-plaintext highlighter-rouge">Classifier Guidance</code> 是“diffusion 模型 + 外部分类器”，<code class="language-plaintext highlighter-rouge">CFG</code> 则是“同一个 diffusion 模型跑两次”。</p>

<h3 id="51-为什么-classifier-guidance-会被边缘化">5.1 为什么 Classifier Guidance 会被边缘化</h3>

<p>主要有四个原因。</p>

<h4 id="原因-1分类器负担太重">原因 1：分类器负担太重</h4>

<p>你不只是多训练了一个网络，而是多训练了一个<strong>噪声鲁棒</strong>分类器。<br />
这件事本身就不轻，而且每换一种条件形式都要重来。</p>

<h4 id="原因-2推理要反向传播">原因 2：推理要反向传播</h4>

<p>Classifier Guidance 采样时需要拿分类器对 <code class="language-plaintext highlighter-rouge">x_t</code> 的梯度。<br />
这意味着：</p>

<ul>
  <li>需要保留更多中间激活</li>
  <li>显存压力更大</li>
  <li>速度通常不如“两次前向”的 CFG</li>
</ul>

<h4 id="原因-3条件类型太受限">原因 3：条件类型太受限</h4>

<p>分类器天然适合“猫 / 狗 / 车”这种离散标签。<br />
但现代生成任务的条件往往是：</p>

<ul>
  <li>长文本 prompt</li>
  <li>图像条件</li>
  <li>layout / depth / segmentation</li>
  <li>多模态混合 embedding</li>
</ul>

<p>这时候“训练一个噪声图上的条件分类器”就变得很不自然。</p>

<h4 id="原因-4高噪声阶段梯度不稳定">原因 4：高噪声阶段梯度不稳定</h4>

<p>当 <code class="language-plaintext highlighter-rouge">x_t</code> 很接近纯噪声时，分类器给出的判断很容易不可靠。<br />
这会导致引导方向 noisy、过激，甚至带来类似对抗样本的伪影。</p>

<hr />

<h2 id="6-为什么直接做条件生成还不够偏偏要有无条件分支">6. 为什么“直接做条件生成”还不够，偏偏要有无条件分支</h2>

<p>这往往是大家第一次学 CFG 时最困惑的点。</p>

<p>一个很自然的问题是：</p>

<blockquote>
  <p>如果模型本来就是条件模型，为什么不直接用 <code class="language-plaintext highlighter-rouge">\epsilon_\theta(x_t,t,c)</code> 去采样？<br />
为什么还要再算一次无条件分支？</p>
</blockquote>

<p>答案是：<strong>因为 CFG 不只是“让模型有条件”，而是在推理阶段对条件方向进行再加权。</strong>
这并不是说纯条件模型不能工作，而是说它少了一个可以在推理时显式放大条件信号的控制手柄。</p>

<h3 id="61-纯条件模型只是在做条件均值意义下的去噪">6.1 纯条件模型只是在做“条件均值意义下的去噪”</h3>

<p>如果 prompt 很宽泛，比如“a cat”，那么满足这个条件的图像分布其实很宽：</p>

<ul>
  <li>橘猫</li>
  <li>黑猫</li>
  <li>正脸</li>
  <li>侧脸</li>
  <li>写实</li>
  <li>插画</li>
</ul>

<p>单纯条件模型学到的是这个条件分布下的平均去噪规律。<br />
而 MSE 型训练目标天然倾向于学习“平均意义上最稳妥的预测”。</p>

<p>结果往往是：</p>

<ul>
  <li>条件满足了</li>
  <li>但语义不够“尖锐”</li>
  <li>细节和风格不够坚定</li>
</ul>

<p>CFG 做的事情则更像是在说：</p>

<blockquote>
  <p>我不仅要满足条件，我还要更坚定地朝条件特征前进。</p>
</blockquote>

<h3 id="62-无条件分支提供了一个基线">6.2 无条件分支提供了一个“基线”</h3>

<p>这一点非常重要。</p>

<p>有了无条件预测，你才能问出下面这个问题：</p>

<blockquote>
  <p>相比于“什么都不说”时的去噪方向，这个条件到底额外改变了什么？</p>
</blockquote>

<p>也就是这项：</p>

\[\epsilon_\theta(x_t,t,c)-\epsilon_\theta(x_t,t,\varnothing)\]

<p>这就是条件的“纯增量”。</p>

<p>没有无条件分支，你只有 <code class="language-plaintext highlighter-rouge">\epsilon_\theta(x_t,t,c)</code>，却不知道里面有多少是：</p>

<ul>
  <li>来自数据分布本身的通用图像先验</li>
  <li>来自条件 <code class="language-plaintext highlighter-rouge">c</code> 的额外要求</li>
</ul>

<p>CFG 恰恰把这两部分显式分离开了。</p>

<h3 id="63-从分布角度看cfg-可以理解为后验锐化">6.3 从分布角度看，CFG 可以理解为“后验锐化”</h3>

<p>很多教程会把 CFG 理解成对条件分布的 sharpen。<br />
直觉上它对应于：</p>

\[\tilde{p}(x|c)\propto \frac{p(x|c)^w}{p(x)^{w-1}}\]

<p>也就是在对数空间里放大条件分布相对于无条件分布的优势方向。</p>

<p>这个视角对于理解“为什么更贴 prompt，但多样性下降”非常有帮助：<br />
<code class="language-plaintext highlighter-rouge">w</code> 越大，分布越尖，语义通常越强，但 mode coverage 往往会下降。</p>

<p>这里要加一个重要注记：</p>

<blockquote>
  <p>上面这个“锐化分布”视角是一个非常有用的直觉，但并不是对所有离散采样器、所有有限步采样过程都严格成立的完整结论。</p>
</blockquote>

<p>近年的理论工作开始更仔细地讨论 CFG 的有限步行为，以及它为什么会出现“过饱和、模式坍缩、编辑不可逆”等副作用；例如 CFG++ 就把部分问题解释为 off-manifold 现象。[7]</p>

<hr />

<h2 id="7-cfg-的实际效果为什么它能长期统治文本生成图像">7. CFG 的实际效果，为什么它能长期统治文本生成图像</h2>

<p>CFG 能成为主流，不只是因为“省了一个分类器”，更因为它刚好踩中了现代大模型系统的需求。</p>

<h3 id="71-它天然支持任意条件-embedding">7.1 它天然支持任意条件 embedding</h3>

<p>只要条件能被编码成向量，CFG 就能工作：</p>

<ul>
  <li>class embedding</li>
  <li>text encoder 输出</li>
  <li>image encoder 输出</li>
  <li>audio embedding</li>
  <li>layout / depth / pose control signal</li>
</ul>

<p>这跟文本生成图像时代的需求几乎完美匹配。</p>

<h3 id="72-它给了推理阶段一个可调节旋钮">7.2 它给了推理阶段一个可调节旋钮</h3>

<p><code class="language-plaintext highlighter-rouge">guidance scale</code> 是一个极其实用的控制参数。</p>

<ul>
  <li>小一些：更多样，但可能没那么贴 prompt</li>
  <li>大一些：更贴 prompt，但更容易失真、过饱和、重复</li>
</ul>

<p>这让同一个基础模型能覆盖很多场景，而不必为每种“对齐强度”重新训练一份。</p>

<h3 id="73-它非常契合-latent-diffusion--stable-diffusion-这类架构">7.3 它非常契合 latent diffusion / Stable Diffusion 这类架构</h3>

<p>现代文本生成图像系统通常采用：</p>

<ol>
  <li>文本编码器把 prompt 编成 embedding</li>
  <li>latent diffusion / UNet 在潜空间做去噪</li>
  <li>采样时同时跑 conditional 和 unconditional 分支</li>
  <li>用 CFG 线性组合</li>
</ol>

<p>这套接口很简单，也很模块化。<br />
所以从 Stable Diffusion 到许多后来的文本生成图像系统，CFG 都成为了默认推理机制。</p>

<hr />

<h2 id="8-cfg-的副作用为什么-guidance-scale-不能无限加大">8. CFG 的副作用：为什么 guidance scale 不能无限加大</h2>

<p>CFG 不是越大越好。</p>

<p>如果把 <code class="language-plaintext highlighter-rouge">w</code> 开得很高，常见问题包括：</p>

<ul>
  <li>图像过饱和</li>
  <li>纹理僵硬</li>
  <li>构图重复</li>
  <li>多样性下降</li>
  <li>细节出现“被强行往 prompt 上扯”的伪影</li>
</ul>

<p>这背后的直觉并不难理解：</p>

<blockquote>
  <p>你在不断放大“条件增量方向”，但这个方向本来只是一个局部修正。<br />
放大过头，就会从“更对齐”变成“过度纠偏”。</p>
</blockquote>

<p>很多用户在 Stable Diffusion 里都有很直观的经验：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">scale</code> 太小，图像“没听话”</li>
  <li><code class="language-plaintext highlighter-rouge">scale</code> 太大，图像“过于用力”</li>
</ul>

<p>从理论和经验上看，CFG 一直在做一件 trade-off：</p>

<blockquote>
  <p>对齐更强，通常意味着多样性更弱。</p>
</blockquote>

<p>这也是后来很多改进工作的出发点。</p>

<hr />

<h2 id="9-后续发展为什么蒸馏模型常常不再需要-cfg">9. 后续发展：为什么蒸馏模型常常“不再需要 CFG”</h2>

<p>理解了 CFG 之后，再看 few-step 模型就顺了。</p>

<p>很多人第一次看到 LCM、SDXL Turbo 这类模型时会觉得奇怪：</p>

<blockquote>
  <p>为什么原始模型要几十步、还要 CFG；<br />
蒸馏后的模型却几步就行，甚至不再依赖传统 CFG？</p>
</blockquote>

<p>答案是：</p>

<blockquote>
  <p>因为 CFG 的效果可以在训练时被“蒸”进学生模型里。</p>
</blockquote>

<h3 id="91-progressive-distillation先解决步数太多">9.1 Progressive Distillation：先解决“步数太多”</h3>

<p>Salimans 和 Ho 在 2022 年提出 <em>Progressive Distillation for Fast Sampling of Diffusion Models</em>，核心思想是把一个多步 deterministic sampler 逐轮蒸成更少步数的模型，每一轮把步数减半。[4]</p>

<p>它解决的是：</p>

<blockquote>
  <p>diffusion 太慢，如何把 8192 步、1024 步慢慢蒸成 4 步？</p>
</blockquote>

<p>这一步不一定直接针对 CFG，但为后面 few-step 生成打了基础。</p>

<h3 id="92-consistency-models直接学习从噪声到数据的快速映射">9.2 Consistency Models：直接学习从噪声到数据的快速映射</h3>

<p>Song 等人在 2023 年提出 <em>Consistency Models</em>，把目标推进到“一步或少步生成”。[5]</p>

<p>它的关键思想是让模型学会不同时间点之间的一致映射，从而绕开传统 diffusion 逐步积分的高成本。</p>

<h3 id="93-latent-consistency-models把-cfg-蒸进-latent-diffusion-体系">9.3 Latent Consistency Models：把 CFG 蒸进 latent diffusion 体系</h3>

<p>真正和现代文本生成图像工作流贴得很近的是 2023 年的 <em>Latent Consistency Models (LCM)</em>。[6]</p>

<p>LCM 很关键的一句话是：</p>

<blockquote>
  <p>它是从<strong>预训练的 classifier-free guided diffusion models</strong> 高效蒸馏出来的。</p>
</blockquote>

<p>换句话说，teacher 本身就带着 CFG 的行为。<br />
学生模型学习的是：</p>

<blockquote>
  <p>teacher 做完 guidance 之后的结果</p>
</blockquote>

<p>于是推理阶段就不必再显式执行：</p>

<ol>
  <li>一次 conditional 前向</li>
  <li>一次 unconditional 前向</li>
  <li>线性组合</li>
</ol>

<p>学生模型已经把“有 guidance 的好处”折进自己参数里了。</p>

<h3 id="94-adversarial-diffusion-distillation把-few-step-做到更激进">9.4 Adversarial Diffusion Distillation：把 few-step 做到更激进</h3>

<p>2023 年的 <em>Adversarial Diffusion Distillation (ADD)</em> 更进一步，把 few-step / one-step 的质量继续往上推。[8]</p>

<p>它利用预训练 diffusion 模型作为 teacher signal，再加上 adversarial loss，目标是在极少步数下依然维持高质量图像。</p>

<p>所以如果把这一整条线串起来，你会得到一个很清晰的演化逻辑：</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>先有多步 diffusion
-&gt; 再有 CFG，让条件生成更强
-&gt; 再有蒸馏，把“多步 + CFG 的能力”压缩进少步学生模型
-&gt; 最后出现 few-step / one-step 的实用系统
</code></pre></div></div>

<p>这也是为什么今天很多快速模型虽然表面上“不再跑传统 CFG”，但它们背后的 teacher 往往仍然深受 CFG 体系影响。</p>

<hr />

<h2 id="10-最近两年的一些延伸大家在改进-cfg-的什么">10. 最近两年的一些延伸：大家在改进 CFG 的什么</h2>

<p>如果说 2021 到 2023 年的主线是“把 CFG 变成标准件”，那么 2024 到 2025 年的很多工作则是在回答：</p>

<blockquote>
  <p>CFG 很有用，但它到底哪里还不够好？</p>
</blockquote>

<h3 id="101-self-attention-guidance不用额外条件也能做训练自由的引导">10.1 Self-Attention Guidance：不用额外条件，也能做训练自由的引导</h3>

<p>Hong 等人在 <em>Self-Attention Guidance</em> 中提出，除了 classifier guidance 和 CFG 之外，还可以利用模型内部 self-attention 信息来做 guidance。[9]</p>

<p>这个方向的重要意义在于：</p>

<ul>
  <li>guidance 不一定非得来自外部分类器</li>
  <li>guidance 也不一定非得来自条件 dropout 训练</li>
  <li>可以从模型内部结构本身提取“纠偏信号”</li>
</ul>

<h3 id="102-pag把-guidance-扩展到无条件与下游任务场景">10.2 PAG：把 guidance 扩展到无条件与下游任务场景</h3>

<p>2024 年的 <em>Perturbed-Attention Guidance (PAG)</em> 则进一步展示：<br />
即使在 unconditional generation 或某些 CFG 不方便使用的任务里，也可以通过扰动 attention 构造 guidance 信号。[10]</p>

<p>这说明一个更大的趋势：</p>

<blockquote>
  <p>“guidance” 已经不再只是一条固定公式，而是在演化成一个更宽的推理控制框架。</p>
</blockquote>

<h3 id="103-cfg讨论-vanilla-cfg-的-off-manifold-问题">10.3 CFG++：讨论 vanilla CFG 的 off-manifold 问题</h3>

<p>2025 年的 <em>CFG++</em> 指出，传统 CFG 的一些副作用并不一定是 diffusion 本身的问题，而可能和 CFG 把采样轨迹推离数据流形有关。[7]</p>

<p>这类工作之所以值得关注，是因为它们开始从“经验调 scale”走向：</p>

<ul>
  <li>更系统地理解 CFG 为什么有效</li>
  <li>更具体地解释 CFG 为什么会失真</li>
  <li>更有针对性地修复它的缺点</li>
</ul>

<hr />

<h2 id="11-一张总图把整条知识脉络串起来">11. 一张总图，把整条知识脉络串起来</h2>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>DDPM / score-based diffusion
  先学会从噪声中恢复“自然样本”
  核心对象是 p(x) 或其 score

        |
        v

Classifier Guidance (2021)
  用无条件 diffusion 提供 p(x)
  再训练噪声分类器提供 ∇ log p(c|x_t)
  条件 score = 无条件 score + 分类器梯度

        |
        v

Classifier-Free Guidance (2021/2022)
  不再训练分类器
  通过条件 dropout 让同一模型同时学会：
    ε(x_t, t, c)
    ε(x_t, t, ∅)
  采样时做：
    ε_cfg = ε_∅ + w (ε_c - ε_∅)

        |
        v

现代文本生成图像系统
  CFG 成为默认推理机制
  通过 guidance scale 控制 prompt 对齐和多样性

        |
        v

Few-step / Distillation 路线
  Progressive Distillation
  Consistency Models
  LCM
  ADD

  把“多步 + CFG”的效果蒸进学生模型

        |
        v

近期改进
  SAG / PAG / CFG++
  从训练自由 guidance、attention guidance、
  以及理论修正等角度继续优化 CFG
</code></pre></div></div>

<hr />

<h2 id="12-一个最常见的误区">12. 一个最常见的误区</h2>

<p>很多教程会把 CFG 说成：</p>

<blockquote>
  <p>“就是条件减去无条件，再乘一个 scale。”</p>
</blockquote>

<p>这当然没错，但如果只停在这一步，会漏掉最关键的理解：</p>

<blockquote>
  <p><code class="language-plaintext highlighter-rouge">ε_cond - ε_uncond</code> 不是一个拍脑袋的 engineering trick，<br />
它来自贝叶斯公式下“分类器梯度 = 条件 score - 无条件 score”的推导。</p>
</blockquote>

<p>也就是说，CFG 不是“经验上好用的 hack”，而是一个有明确 probabilistic 来历、又极度工程友好的近似方案。</p>

<p>它真正厉害的地方在于：</p>

<ul>
  <li>数学上和 classifier guidance 是一脉相承的</li>
  <li>工程上却省掉了最麻烦的那个分类器</li>
  <li>还顺手把条件接口扩展成了任意 embedding</li>
</ul>

<p>这就是为什么它几乎成为现代 diffusion 条件生成的默认答案。</p>

<hr />

<h2 id="13-总结">13. 总结</h2>

<p>如果只用一句话总结整条脉络，我会写成：</p>

<blockquote>
  <p><code class="language-plaintext highlighter-rouge">Classifier Guidance</code> 证明了 diffusion 可以被“引导”；<br />
<code class="language-plaintext highlighter-rouge">Classifier-Free Guidance</code> 则把这种引导从一个昂贵、受限的两模型系统，变成了一个几乎所有条件 diffusion 都能直接使用的标准模块。</p>
</blockquote>

<p>更具体地说：</p>

<ol>
  <li>无条件 diffusion 只学 <code class="language-plaintext highlighter-rouge">p(x)</code>，不知道用户要什么。</li>
  <li>Classifier Guidance 用额外分类器给出 <code class="language-plaintext highlighter-rouge">\nabla \log p(c|x_t)</code>，第一次把条件引导明确做进 diffusion 采样。</li>
  <li>CFG 发现这个分类器梯度可以由“条件 score - 无条件 score”替代，于是只靠一个模型、两次前向就能完成引导。</li>
  <li>CFG 因为简单、通用、兼容文本条件，最终成为文本生成图像时代的主流。</li>
  <li>后续的蒸馏与 consistency 路线，又把“多步 + CFG”的能力进一步压缩进 few-step 模型。</li>
</ol>

<p>所以从历史上看，CFG 不是 diffusion 里的一个小技巧，它几乎就是现代条件 diffusion 能真正大规模落地的关键转折点之一。</p>

<hr />

<h2 id="参考资料">参考资料</h2>

<ol>
  <li>
    <p>Ho, Jain, Abbeel. <em>Denoising Diffusion Probabilistic Models</em>. NeurIPS 2020.<br />
<a href="https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf">Paper</a></p>
  </li>
  <li>
    <p>Dhariwal, Nichol. <em>Diffusion Models Beat GANs on Image Synthesis</em>. NeurIPS 2021.<br />
<a href="https://proceedings.neurips.cc/paper/2021/file/49ad23d1ec9fa4bd8d77d02681df5cfa-Paper.pdf">Paper</a></p>
  </li>
  <li>
    <p>Ho, Salimans. <em>Classifier-Free Diffusion Guidance</em>. NeurIPS 2021 Workshop / OpenReview.<br />
<a href="https://openreview.net/forum?id=qw8AKxfYbI">OpenReview</a></p>
  </li>
  <li>
    <p>Salimans, Ho. <em>Progressive Distillation for Fast Sampling of Diffusion Models</em>. ICLR 2022.<br />
<a href="https://arxiv.org/abs/2202.00512">arXiv</a></p>
  </li>
  <li>
    <p>Song, Dhariwal, Chen, Sutskever. <em>Consistency Models</em>. ICML 2023.<br />
<a href="https://proceedings.mlr.press/v202/song23a.html">PMLR</a></p>
  </li>
  <li>
    <p>Luo, Tan, Huang, Li, Zhao. <em>Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference</em>. 2023.<br />
<a href="https://arxiv.org/abs/2310.04378">arXiv</a></p>
  </li>
  <li>
    <p>Chung, Kim, Park, Nam, Ye. <em>CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models</em>. ICLR 2025.<br />
<a href="https://openreview.net/forum?id=E77uvbOTtp">OpenReview</a></p>
  </li>
  <li>
    <p>Sauer, Lorenz, Blattmann, Rombach. <em>Adversarial Diffusion Distillation</em>. 2023.<br />
<a href="https://arxiv.org/abs/2311.17042">arXiv</a></p>
  </li>
  <li>
    <p>Hong, Lee, Jang, Kim. <em>Improving Sample Quality of Diffusion Models Using Self-Attention Guidance</em>. ICCV 2023.<br />
<a href="https://openaccess.thecvf.com/content/ICCV2023/html/Hong_Improving_Sample_Quality_of_Diffusion_Models_Using_Self-Attention_Guidance_ICCV_2023_paper.html">Paper</a></p>
  </li>
  <li>
    <p>Ahn et al. <em>Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance</em>. ECCV 2024 / arXiv.<br />
<a href="https://arxiv.org/abs/2403.17377">arXiv</a></p>
  </li>
</ol>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="diffusion" /><category term="generative model" /><category term="machine learning" /><category term="computer vision" /><summary type="html"><![CDATA[Diffusion 模型发展到今天，CFG 几乎已经成了文本生成图像系统里的“默认组件”。 但很多人第一次看到它时都会困惑：]]></summary></entry><entry><title type="html">从 DDPO 到 Flow-GRPO：一文看懂 Diffusion 模型的强化学习过程与发展脉络</title><link href="https://liyongzhi.xyz/posts/2026/04/diffusion-rl/" rel="alternate" type="text/html" title="从 DDPO 到 Flow-GRPO：一文看懂 Diffusion 模型的强化学习过程与发展脉络" /><published>2026-04-20T00:00:00+08:00</published><updated>2026-04-20T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-diffusion-rl</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/diffusion-rl/"><![CDATA[<p>Diffusion 模型最初是按“去噪 MSE / 似然近似”来训练的，但真正上线时，我们更关心的往往不是似然，而是：</p>

<ul>
  <li>人类是否更喜欢这张图</li>
  <li>图像和 prompt 是否更对齐</li>
  <li>视频动作是否更连贯</li>
  <li>输出是否更安全</li>
  <li>结果是否更符合物理或任务约束</li>
</ul>

<p>这类目标通常没有一个漂亮、统一、可微、稳定的监督损失。<br />
于是 2023 年开始，一批工作把 Diffusion 的采样过程重新解释成<strong>多步决策</strong>，再用 RL、偏好优化、reward backprop 等方法对它做后训练。</p>

<p>这篇文章想讲清两件事：</p>

<ol>
  <li>Diffusion 模型为什么能被看成一个 RL 问题，以及一次 RL fine-tuning 到底在做什么。</li>
  <li>这条线是如何从 <code class="language-plaintext highlighter-rouge">DDPO / DPOK</code>，发展到 <code class="language-plaintext highlighter-rouge">Diffusion-DPO / D3PO</code>、<code class="language-plaintext highlighter-rouge">DRaFT / AlignProp</code>、视频对齐，再走到 <code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 和 2026 年更高效的方法。</li>
</ol>

<p><br />
<img align="center" width="1000" src="https://liyongzhi.xyz/images/posts/diffusion-rl-evolution.svg" alt="Diffusion RL evolution timeline" />
<br /></p>

<p>上图可以先当全文导航：<br />
2023 年是“把去噪过程变成决策过程”的起点；2023 年下半年到 2024 年，方法开始沿着不同反馈类型分叉；2025 年进入 Flow Matching 时代；到 2026 年，研究重点明显转向了<strong>效率、credit assignment 和对 reverse process likelihood 的替代</strong>。</p>

<hr />

<h2 id="1-为什么要用-rl-优化-diffusion">1. 为什么要用 RL 优化 Diffusion</h2>

<p>标准 Diffusion 训练，学到的是一个“如何把噪声慢慢拉回数据分布”的模型。<br />
在 DDPM 记号下，它通常优化的是噪声预测误差：</p>

\[\mathcal{L}_{\text{diffusion}} = \mathbb{E}\left[\|\epsilon - \epsilon_\theta(x_t, t, c)\|^2\right]\]

<p>这个目标和“生成结果更符合人类偏好”之间，并不是一回事。</p>

<p>更准确地说，预训练阶段优化的是：</p>

<blockquote>
  <p>生成样本要像训练分布中的样本。</p>
</blockquote>

<p>而后训练阶段往往优化的是：</p>

<blockquote>
  <p>在不要严重偏离预训练分布的前提下，让样本在某个外部指标上拿更高分。</p>
</blockquote>

<p>外部指标可以是：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">ImageReward</code>、<code class="language-plaintext highlighter-rouge">HPSv2</code> 这类偏好或美学 reward model</li>
  <li><code class="language-plaintext highlighter-rouge">CLIP</code>、VLM 或 OCR-based 的对齐指标</li>
  <li>视频里的时序一致性、运动平滑性</li>
  <li>分子、蛋白、材料里的任务分数</li>
  <li>人类偏好对本身</li>
</ul>

<p>RL 的价值就在这里：<br />
你不再要求目标必须长得像一个监督学习 loss，而是只要求它能为最终样本给出一个分数，或者至少给出偏好关系。</p>

<hr />

<h2 id="2-关键重写去噪过程其实是一个-mdp">2. 关键重写：去噪过程其实是一个 MDP</h2>

<p>Diffusion RL 的真正起点，不是某个具体算法，而是这个重写：</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>state      s_t = (x_t, t, c)
action     a_t = sample or predict the next denoising move
transition x_t -&gt; x_{t-1}
reward     r(x_0, c) or preference over final samples
policy     the diffusion / flow model itself
</code></pre></div></div>

<p>如果用 DDPM 风格的随机反向过程来看，模型每一步都定义了一个条件高斯分布：</p>

\[p_\theta(x_{t-1}\mid x_t, c)=\mathcal{N}(\mu_\theta(x_t,t,c), \sigma_t^2 I)\]

<p>这时“动作”可以理解为：<br />
<strong>在状态 <code class="language-plaintext highlighter-rouge">x_t</code> 下，策略选择了一个从高斯反向转移里采样出来的 <code class="language-plaintext highlighter-rouge">x_{t-1}</code>。</strong></p>

<p>于是整条采样轨迹就像：</p>

<div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>x_T -&gt; x_{T-1} -&gt; ... -&gt; x_1 -&gt; x_0
</code></pre></div></div>

<p>最后只在 <code class="language-plaintext highlighter-rouge">x_0</code> 上拿到一个终止 reward。<br />
这正是 RL 里最经典、也最麻烦的一类问题：</p>

<ul>
  <li>奖励是稀疏的</li>
  <li>credit assignment 跨越很多步</li>
  <li>你还不希望模型把预训练学到的图像先验彻底破坏掉</li>
</ul>

<h3 id="21-compute_log_prob-到底在算什么">2.1 <code class="language-plaintext highlighter-rouge">compute_log_prob</code> 到底在算什么</h3>

<p>很多人第一次看 DDPO 代码时，最困惑的是 <code class="language-plaintext highlighter-rouge">compute_log_prob</code>。<br />
它算的其实非常朴素：</p>

<blockquote>
  <p>在当前策略下，模型把 <code class="language-plaintext highlighter-rouge">x_t</code> 变成这一步实际采样到的 <code class="language-plaintext highlighter-rouge">x_{t-1}</code>，这件事的对数概率有多大。</p>
</blockquote>

<p>因为反向过程是高斯分布，所以：</p>

\[\log p_\theta(x_{t-1}\mid x_t,c)
\propto
-\frac{1}{2\sigma_t^2}\|x_{t-1}-\mu_\theta(x_t,t,c)\|^2\]

<p>这件事之所以重要，是因为 policy gradient 需要的正是：</p>

\[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\]

<p>在 Diffusion 里，它就变成了每一步反向转移的 log-prob gradient。</p>

<h3 id="22-为什么-ddpm-比-ddim-更容易做-policy-gradient">2.2 为什么 DDPM 比 DDIM 更容易做 policy gradient</h3>

<p>如果你用的是 DDPM 风格采样，每一步天然有随机性，高斯转移和 log-prob 都是良定义的。<br />
但 DDIM 和很多 Flow Matching 采样器本质上更接近确定性 ODE，这会带来两个麻烦：</p>

<ol>
  <li>exploration 不够自然</li>
  <li>likelihood / log-prob 不容易直接写出来</li>
</ol>

<p>这也是为什么后面 <code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 需要专门做 <strong>ODE-to-SDE conversion</strong>，而 2026 年一些方法开始尝试<strong>不再直接依赖 reverse-process likelihood</strong>。[11][13]</p>

<hr />

<h2 id="3-一次-diffusion-rl-训练迭代到底发生了什么">3. 一次 Diffusion RL 训练迭代到底发生了什么</h2>

<p>先不急着分论文，先看一次最典型的 <code class="language-plaintext highlighter-rouge">DDPO / PPO</code> 风格训练循环。<br />
把细节都压缩掉，它本质上只有五步：</p>

<h3 id="31-第一步用旧策略生成完整轨迹">3.1 第一步：用旧策略生成完整轨迹</h3>

<p>给一批 prompt，用当前模型 <code class="language-plaintext highlighter-rouge">\theta_{\text{old}}</code> 从噪声开始生成图像，并记录整条去噪轨迹：</p>

<ul>
  <li>每个时间步的 <code class="language-plaintext highlighter-rouge">x_t</code></li>
  <li>实际采样得到的 <code class="language-plaintext highlighter-rouge">x_{t-1}</code></li>
  <li>旧策略下对应的 <code class="language-plaintext highlighter-rouge">old_log_prob</code></li>
</ul>

<h3 id="32-第二步对最终样本打分">3.2 第二步：对最终样本打分</h3>

<p>只在最后生成出的 <code class="language-plaintext highlighter-rouge">x_0</code> 上调用 reward：</p>

\[r = r(x_0, c)\]

<p>这个 reward 可以是：</p>

<ul>
  <li>美学分数</li>
  <li>prompt 对齐分数</li>
  <li>安全分数</li>
  <li>视频时序 reward</li>
  <li>人类偏好数据训练出来的 reward model</li>
</ul>

<h3 id="33-第三步把-reward-变成-advantage">3.3 第三步：把 reward 变成 advantage</h3>

<p>最简单时，整条轨迹共享同一个终止 reward；更稳一些时会做 batch normalization、baseline 或组内归一化：</p>

\[A = \frac{r - \mathrm{mean}(r)}{\mathrm{std}(r) + \epsilon}\]

<h3 id="34-第四步用新参数重算每一步-log-prob">3.4 第四步：用新参数重算每一步 log prob</h3>

<p>对同一条轨迹，用当前正在更新的参数重新计算：</p>

\[\rho_t = \exp\left(\log p_\theta(x_{t-1}\mid x_t,c) - \log p_{\theta_{\text{old}}}(x_{t-1}\mid x_t,c)\right)\]

<p>然后套进 PPO clip 或带 KL 的目标里。</p>

<h3 id="35-第五步让高-reward-轨迹更可能低-reward-轨迹更不可能">3.5 第五步：让高 reward 轨迹更可能、低 reward 轨迹更不可能</h3>

<p>如果一张图最终分数高，就提升这条去噪轨迹上各步动作的概率；<br />
如果一张图分数低，就降低这些动作的概率。</p>

<p>直觉上，Diffusion RL 学到的不是某个“神秘奖励魔法”，而是：</p>

<blockquote>
  <p>哪些去噪路径更容易通向人类真正想要的结果。</p>
</blockquote>

<hr />

<h2 id="4-发展脉络这条线是怎么长出来的">4. 发展脉络：这条线是怎么长出来的</h2>

<p>到这里已经能理解“过程”了，接下来再看历史就会顺很多。<br />
我更推荐按<strong>反馈类型</strong>和<strong>建模约束</strong>来看，而不是只按年份背论文名。</p>

<h3 id="41-2023从-reward-model-到-online-rl">4.1 2023：从 reward model 到 online RL</h3>

<p>2023 年最重要的转折，是社区开始承认：</p>

<blockquote>
  <p>Diffusion 的预训练目标和下游目标不一致，所以需要单独的 post-training。</p>
</blockquote>

<p>这一年最早的标志性工作之一是 <code class="language-plaintext highlighter-rouge">ImageReward</code>。它不仅提出了一个通用文本生成图像 reward model，还给出了 <code class="language-plaintext highlighter-rouge">ReFL</code>，把 reward feedback 直接用于模型调优。[1]</p>

<p>紧接着，<code class="language-plaintext highlighter-rouge">DDPO</code> 在 2023 年 5 月把 denoising 明确建模成多步决策过程，并系统引入 policy gradient / PPO 风格更新。[2]</p>

<p>几天后，<code class="language-plaintext highlighter-rouge">DPOK</code> 进一步强调了 <strong>KL regularization</strong> 的重要性：<br />
你不只是要优化 reward，还要约束模型不要偏离预训练分布太远，否则很快就会 reward hacking、图像质量塌陷。[3]</p>

<p>这一阶段的核心思想可以压缩成一句话：</p>

<blockquote>
  <p>先承认 reward 存在，再把 Diffusion 当策略来优化。</p>
</blockquote>

<h3 id="42-2023-下半年到-2024按反馈类型开始分叉">4.2 2023 下半年到 2024：按反馈类型开始分叉</h3>

<p>当“Diffusion 可以做 RL 后训练”这个大门打开后，下一步的问题自然变成：</p>

<blockquote>
  <p>你手里到底有什么反馈信号？</p>
</blockquote>

<p>如果答案不同，方法也会不同。</p>

<h4 id="路线-a你有黑盒-scalar-reward">路线 A：你有黑盒 scalar reward</h4>

<p>那就最适合 <code class="language-plaintext highlighter-rouge">DDPO / DPOK</code> 这种 policy gradient 路线：</p>

<ul>
  <li>reward 可黑盒</li>
  <li>不要求可微</li>
  <li>但方差高，采样贵</li>
</ul>

<h4 id="路线-b你有偏好对没有-reward-model">路线 B：你有偏好对，没有 reward model</h4>

<p>那就更接近 <code class="language-plaintext highlighter-rouge">DPO</code> 家族。</p>

<p><code class="language-plaintext highlighter-rouge">Diffusion-DPO</code> 在 2023 年 11 月把 LLM 里的 DPO 思路迁到 text-to-image diffusion：不先训 reward model，而是直接用人类偏好对优化模型相对偏好概率。[6]</p>

<p>同月提交、后被 CVPR 2024 接收的 <code class="language-plaintext highlighter-rouge">D3PO</code>（<em>Using Human Feedback to Fine-tune Diffusion Models without Any Reward Model</em>）则进一步把“无 reward model 的直接偏好优化”扩展到多步 denoising MDP 视角。[7]</p>

<p>这一路的核心不是“直接做 RL”，而是：</p>

<blockquote>
  <p>如果你已经拿到了 winner / loser 对，就没必要再绕一圈学一个 reward model。</p>
</blockquote>

<h4 id="路线-c你的-reward-本身可微">路线 C：你的 reward 本身可微</h4>

<p>那就完全没必要忍受 REINFORCE / PPO 的高方差。</p>

<p><code class="language-plaintext highlighter-rouge">DRaFT</code> 在 2023 年 9 月提出，直接把 reward 梯度穿过采样过程反传回来，并进一步给出：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">DRaFT-K</code>：只反传最后 <code class="language-plaintext highlighter-rouge">K</code> 步</li>
  <li><code class="language-plaintext highlighter-rouge">DRaFT-LV</code>：在 <code class="language-plaintext highlighter-rouge">K=1</code> 时进一步降方差</li>
</ul>

<p>它的关键 trade-off 很直白：</p>

<ul>
  <li>梯度更准</li>
  <li>样本效率更高</li>
  <li>但显存和反传成本更高</li>
</ul>

<p><code class="language-plaintext highlighter-rouge">AlignProp</code> 也属于这条 reward-backprop 路线，并通过 <code class="language-plaintext highlighter-rouge">LoRA + gradient checkpointing</code> 让直接反传更实用。[5]</p>

<p>这里有一个需要澄清的点：<br />
<code class="language-plaintext highlighter-rouge">AlignProp</code> 的 arXiv 条目后来被作者撤回，并在页面上注明内容被后续工作吸收；但“通过 reward gradient 直接调 diffusion”的路线本身没有消失，反而继续扩展到了视频。[5][9]</p>

<hr />

<h2 id="5-三条主技术路线到底该怎么区分">5. 三条主技术路线，到底该怎么区分</h2>

<p>如果把 2023 到 2026 的方法压缩成一个表，最有用的不是“谁早谁晚”，而是下面这张图。</p>

<p><br />
<img align="center" width="1000" src="https://liyongzhi.xyz/images/posts/diffusion-rl-selector.svg" alt="How to choose among diffusion RL methods" />
<br /></p>

<p>再配合这张表会更直观：</p>

<table>
  <thead>
    <tr>
      <th>路线</th>
      <th>代表方法</th>
      <th>你需要什么反馈</th>
      <th>优点</th>
      <th>代价</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Policy Gradient</td>
      <td>DDPO, DPOK, Flow-GRPO</td>
      <td>黑盒 scalar reward</td>
      <td>通用，reward 不必可微</td>
      <td>方差高，采样贵</td>
    </tr>
    <tr>
      <td>Preference Optimization</td>
      <td>Diffusion-DPO, D3PO</td>
      <td>winner / loser 偏好对</td>
      <td>不必先训 reward model</td>
      <td>需要成对偏好数据，likelihood 近似更复杂</td>
    </tr>
    <tr>
      <td>Direct Backprop</td>
      <td>DRaFT, AlignProp, Video Reward Gradients</td>
      <td>可微 reward</td>
      <td>梯度低方差，样本效率高</td>
      <td>显存大，reward 必须可微</td>
    </tr>
  </tbody>
</table>

<p>这里最值得记住的一点是：</p>

<blockquote>
  <p>这三条线不是互相否定，而是在回答不同的问题。</p>
</blockquote>

<p>不是所有场景都该上 PPO，也不是所有场景都该上 DPO。<br />
真正的分界线其实是：<strong>你的反馈是什么形式、能不能反传、采样预算有多贵。</strong></p>

<hr />

<h2 id="6-视频生成把问题又抬高了一个量级">6. 视频生成把问题又抬高了一个量级</h2>

<p>图像里已经够难了，视频里问题会立刻放大。</p>

<p>原因很简单。<br />
视频 reward 往往不是一个“单帧好不好看”的问题，而是至少同时包含：</p>

<ul>
  <li>单帧质量</li>
  <li>文本对齐</li>
  <li>时序一致性</li>
  <li>运动合理性</li>
  <li>镜头与物理连续性</li>
</ul>

<p>而视频采样又比图像慢得多，显存开销也高得多。</p>

<h3 id="61-instructvideo把视频-reward-fine-tuning-重新表述成-editing">6.1 InstructVideo：把视频 reward fine-tuning 重新表述成 editing</h3>

<p><code class="language-plaintext highlighter-rouge">InstructVideo</code> 在 2023 年 12 月提交、后被 CVPR 2024 接收。<br />
它的做法很典型：不是傻乎乎每次都把完整 DDIM 采样链跑到底，而是把 fine-tuning 重写成 editing，从而减少 full-chain sampling 成本。[8]</p>

<p>更重要的是，它面对了视频路线最现实的问题之一：</p>

<blockquote>
  <p>当时并没有一个像 ImageReward 那样成熟的视频偏好 reward model。</p>
</blockquote>

<p>所以它把图像 reward model 重用于视频，并提出：</p>

<ul>
  <li><code class="language-plaintext highlighter-rouge">Segmental Video Reward</code></li>
  <li><code class="language-plaintext highlighter-rouge">Temporally Attenuated Reward</code></li>
</ul>

<p>本质上是在说：<br />
视频 reward 不能只看“最后整段视频的一个总分”，而要想办法把 reward 更稳定地分配到片段和时序结构上。</p>

<h3 id="62-video-diffusion-alignment-via-reward-gradients把-direct-backprop-扩展到视频">6.2 Video Diffusion Alignment via Reward Gradients：把 direct backprop 扩展到视频</h3>

<p>2024 年 7 月的 <code class="language-plaintext highlighter-rouge">Video Diffusion Alignment via Reward Gradients</code> 则把 reward-backprop 路线明确推进到视频 Diffusion：<br />
既然 reward model 对 RGB 像素有稠密梯度，那就把这个梯度直接反传回视频生成过程。[9]</p>

<p>这条线的意义在于：</p>

<ul>
  <li>它说明 direct backprop 不只是图像 trick</li>
  <li>在视频这种搜索空间更大、采样更贵的场景里，低方差梯度反而更有价值</li>
</ul>

<p>所以如果你关心的是“视频 Diffusion 的 RL 怎么做”，真正要抓住的不是某个单独 paper 名，而是这两个事实：</p>

<ol>
  <li>视频 reward 一定要显式考虑时序结构。</li>
  <li>由于采样成本太高，视频场景通常更偏爱 editing、局部反传、LoRA 和稀疏 reward 设计。</li>
</ol>

<hr />

<h2 id="7-flow-matching-时代为什么-flow-grpo-是新的转折点">7. Flow Matching 时代：为什么 <code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 是新的转折点</h2>

<p>到了 2025 年，主流大模型里已经不全是传统 DDPM/DDIM 了。<br />
像 SD3、FLUX 这类系统更接近 <strong>Flow Matching / Rectified Flow</strong> 范式。</p>

<p>这时旧问题又回来了：</p>

<blockquote>
  <p>RL 需要随机性和可处理的 log-prob；<br />
但 Flow Matching 的采样通常是确定性 ODE。</p>
</blockquote>

<p><code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 在 2025 年 NeurIPS 上给出的回答是两个关键设计：[11]</p>

<h3 id="71-ode-to-sde-conversion">7.1 ODE-to-SDE conversion</h3>

<p>它把原本的确定性 ODE 改写成与原边际分布一致的 SDE，于是：</p>

<ul>
  <li>采样过程重新拥有随机性</li>
  <li>exploration 成立</li>
  <li>每一步转移又能写成统计上可处理的形式</li>
</ul>

<p>这一步非常重要，因为它不是在“给 Flow 模型硬套 DDPO”，而是在修补：</p>

<blockquote>
  <p>Flow 模型天然不适合直接做 reverse-process policy gradient</p>
</blockquote>

<h3 id="72-denoising-reduction">7.2 Denoising Reduction</h3>

<p><code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 的第二个关键点是：<br />
训练时减少 denoising steps，推理时保留原本的高质量步数。</p>

<p>这说明一个很现实的经验：</p>

<blockquote>
  <p>RL 后训练并不一定需要最完美的生成样本，只需要足够稳定、足够可区分的 reward 信号。</p>
</blockquote>

<p>从工程角度看，这让 Flow 模型的 RL 后训练第一次真正变得可用。</p>

<hr />

<h2 id="8-截至-2026-年-4-月这条线又在往哪里走">8. 截至 2026 年 4 月，这条线又在往哪里走</h2>

<p>如果只看 2023 到 2025，你会觉得主线大概是：</p>

<p><code class="language-plaintext highlighter-rouge">DDPO -&gt; DPO / direct backprop -&gt; Flow-GRPO</code></p>

<p>但到 2026 年，研究重点已经明显转向了两个更细的问题。</p>

<h3 id="81-更好的-credit-assignment-和-rollout-复用">8.1 更好的 credit assignment 和 rollout 复用</h3>

<p><code class="language-plaintext highlighter-rouge">TreeGRPO</code>（ICLR 2026）把 denoising 过程重写成一棵搜索树，通过共享轨迹前缀来提升样本效率，并尝试解决“终止 reward 被粗暴平均分给所有时间步”的问题。[12]</p>

<p>这件事说明社区已经意识到：<br />
<strong>uniform terminal reward assignment</strong> 其实是老一代 DDPO / GRPO 风格方法的一个核心瓶颈。</p>

<h3 id="82-不再执着于-reverse-process-likelihood">8.2 不再执着于 reverse-process likelihood</h3>

<p><code class="language-plaintext highlighter-rouge">DiffusionNFT</code>（ICLR 2026 Oral）更进一步，直接质疑了“必须在 reverse sampling 里估 log-prob 才能做 online RL”这个前提。[13]</p>

<p>它的出发点很清楚：</p>

<ul>
  <li>reverse likelihood 受 solver 选择限制</li>
  <li>和 CFG 的兼容性复杂</li>
  <li>trajectory 级策略优化的成本太高</li>
</ul>

<p>所以它改走了 forward-process / flow-matching 风格目标，并声称在效率上显著超过 <code class="language-plaintext highlighter-rouge">Flow-GRPO</code>。[13]</p>

<p>这说明截至 <strong>2026 年 4 月 20 日</strong>，这个领域已经不再只是“把 PPO 套到去噪轨迹上”，而是在认真重构：</p>

<ul>
  <li>policy objective 到底该写成什么</li>
  <li>log-prob 是不是必须的</li>
  <li>terminal reward 该如何分配到各步</li>
  <li>采样器、CFG、solver 和 RL 目标能不能更自然地统一</li>
</ul>

<hr />

<h2 id="9-实践里最难的其实不是公式而是-reward-design">9. 实践里最难的其实不是公式，而是 reward design</h2>

<p>从工程上看，Diffusion RL 最大的敌人一直都不是“推不出梯度”，而是 <strong>reward hacking</strong>。</p>

<p>典型例子包括：</p>

<ul>
  <li>只优化美学分，模型会学会“糖水色”和过饱和</li>
  <li>只优化 CLIP 对齐，模型可能学会骗 VLM</li>
  <li>只优化时序一致性，视频可能变成几乎静止</li>
</ul>

<p>所以大多数可用系统都会同时做四件事：</p>

<h3 id="91-用-kl-约束守住预训练分布">9.1 用 KL 约束守住预训练分布</h3>

<p><code class="language-plaintext highlighter-rouge">DPOK</code> 之后，这几乎已经是标配。<br />
你可以把它理解为一个护栏：</p>

<blockquote>
  <p>允许模型变好，但不允许它为了 reward 快速偏航。</p>
</blockquote>

<h3 id="92-用多维-reward-而不是单分数独裁">9.2 用多维 reward 而不是单分数独裁</h3>

<p>实际系统里更常见的是加权组合：</p>

<ul>
  <li>质量</li>
  <li>对齐</li>
  <li>安全</li>
  <li>时序</li>
  <li>OCR / counting / composition</li>
</ul>

<p>单一 reward 往往最容易被钻空子。</p>

<h3 id="93-用-lora低步数反传和局部更新控制成本">9.3 用 LoRA、低步数反传和局部更新控制成本</h3>

<p>这是 <code class="language-plaintext highlighter-rouge">DRaFT</code>、<code class="language-plaintext highlighter-rouge">AlignProp</code>、视频 alignment 工作都反复验证过的经验：</p>

<ul>
  <li>不是所有参数都要动</li>
  <li>不是所有时间步都要反传</li>
  <li>不是每次都要完整链路采样</li>
</ul>

<h3 id="94-先问清楚你手里的反馈是什么">9.4 先问清楚“你手里的反馈是什么”</h3>

<p>这是最重要的实操建议：</p>

<ul>
  <li>如果 reward 可微，优先考虑 direct backprop</li>
  <li>如果只有偏好对，优先考虑 DPO 类方法</li>
  <li>如果 reward 是黑盒且可调用，再考虑 policy gradient</li>
  <li>如果模型已经是 Flow Matching，要特别注意随机性和 likelihood 的定义问题</li>
</ul>

<hr />

<h2 id="10-应该怎么选方法">10. 应该怎么选方法</h2>

<p>如果只给一个实用版判断树，我会这样用：</p>

<ol>
  <li>
    <p>reward 可微吗？
可微：优先 <code class="language-plaintext highlighter-rouge">DRaFT / AlignProp / Video Reward Gradients</code> 这类直接反传路线。</p>
  </li>
  <li>
    <p>reward 不可微，但有偏好对吗？
有：优先 <code class="language-plaintext highlighter-rouge">Diffusion-DPO / D3PO</code>。</p>
  </li>
  <li>
    <p>既没有可微 reward，也没有偏好对，只有黑盒打分器？
图像 DDPM 类：<code class="language-plaintext highlighter-rouge">DDPO / DPOK</code>。<br />
Flow Matching 类：先看 <code class="language-plaintext highlighter-rouge">Flow-GRPO</code>，再关注 <code class="language-plaintext highlighter-rouge">DiffusionNFT</code> 这类新范式。</p>
  </li>
  <li>
    <p>是视频吗？
默认把“reward 设计”和“采样成本”放在首位，再决定是 editing 式 fine-tuning、direct backprop，还是更传统的 RL 更新。</p>
  </li>
</ol>

<hr />

<h2 id="11-总结">11. 总结</h2>

<p>如果只用一句话总结这条脉络，我会写成：</p>

<blockquote>
  <p>Diffusion RL 的本质，不是把 PPO 生搬硬套到生成模型上；<br />
而是把“逐步去噪”重新看成一个可优化的决策过程，再根据你手里的反馈形式，选择 policy gradient、偏好优化或 reward backprop 这三类不同工具。</p>
</blockquote>

<p>更具体地说：</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">DDPO / DPOK</code> 解决的是“黑盒 reward 怎么直接优化”。</li>
  <li><code class="language-plaintext highlighter-rouge">Diffusion-DPO / D3PO</code> 解决的是“只有偏好对时，能不能跳过 reward model”。</li>
  <li><code class="language-plaintext highlighter-rouge">DRaFT / AlignProp</code> 解决的是“如果 reward 可微，为什么还要忍受高方差 RL”。</li>
  <li><code class="language-plaintext highlighter-rouge">InstructVideo</code> 和后续视频工作说明，视频不是图像方法的简单复制，而是一个 reward design 和效率问题都更尖锐的场景。</li>
  <li><code class="language-plaintext highlighter-rouge">Flow-GRPO</code> 则标志着这条线正式进入 Flow Matching 时代。</li>
  <li>到 2026 年，研究已经开始进一步追问：<code class="language-plaintext highlighter-rouge">log-prob</code> 是否必须、credit assignment 能否更细、采样器与 RL 目标能否更统一。</li>
</ol>

<p>所以从历史上看，Diffusion 模型的“强化学习过程”并不是一套固定算法，而是一条不断重新定义<strong>策略、反馈、轨迹和约束</strong>的演化路线。</p>

<hr />

<h2 id="参考资料">参考资料</h2>

<ol>
  <li>
    <p>Xu, Liu, Wu, Tong, Li, Ding, Tang, Dong. <em>ImageReward: Learning and Evaluating Human Preferences for Text-to-Image Generation</em>. arXiv 2023.<br />
<a href="https://arxiv.org/abs/2304.05977">arXiv</a></p>
  </li>
  <li>
    <p>Black, Janner, Du, Kostrikov, Levine. <em>Training Diffusion Models with Reinforcement Learning</em>. arXiv 2023.<br />
<a href="https://arxiv.org/abs/2305.13301">arXiv</a> | <a href="https://rl-diffusion.github.io/">Project</a></p>
  </li>
  <li>
    <p>Fan, Watkins, Du, Liu, Ryu, Boutilier, Abbeel, Ghavamzadeh, K. Lee, K. Lee. <em>DPOK: Reinforcement Learning for Fine-tuning Text-to-Image Diffusion Models</em>. arXiv 2023.<br />
<a href="https://arxiv.org/abs/2305.16381">arXiv</a></p>
  </li>
  <li>
    <p>Clark, Vicol, Swersky, Fleet. <em>Directly Fine-Tuning Diffusion Models on Differentiable Rewards</em>. arXiv 2023 / ICLR 2024.<br />
<a href="https://arxiv.org/abs/2309.17400">arXiv</a></p>
  </li>
  <li>
    <p>Prabhudesai, Goyal, Pathak, Fragkiadaki. <em>Aligning Text-to-Image Diffusion Models with Reward Backpropagation</em>. arXiv 2023.<br />
<a href="https://arxiv.org/abs/2310.03739">arXiv</a> | <a href="https://align-prop.github.io/">Project</a></p>
  </li>
  <li>
    <p>Wallace, Dang, Rafailov, Zhou, Lou, Purushwalkam, Ermon, Xiong, Joty, Naik. <em>Diffusion Model Alignment Using Direct Preference Optimization</em>. arXiv 2023.<br />
<a href="https://arxiv.org/abs/2311.12908">arXiv</a></p>
  </li>
  <li>
    <p>Yang, Tao, Lyu, Ge, Chen, Li, Shen, Zhu, Li. <em>Using Human Feedback to Fine-tune Diffusion Models without Any Reward Model</em>. arXiv 2023 / CVPR 2024.<br />
<a href="https://arxiv.org/abs/2311.13231">arXiv</a></p>
  </li>
  <li>
    <p>Yuan, Zhang, Wang, Wei, Feng, Pan, Zhang, Liu, Albanie, Ni. <em>InstructVideo: Instructing Video Diffusion Models with Human Feedback</em>. arXiv 2023 / CVPR 2024.<br />
<a href="https://arxiv.org/abs/2312.12490">arXiv</a></p>
  </li>
  <li>
    <p>Prabhudesai, Mendonca, Qin, Fragkiadaki, Pathak. <em>Video Diffusion Alignment via Reward Gradients</em>. arXiv 2024.<br />
<a href="https://arxiv.org/abs/2407.08737">arXiv</a></p>
  </li>
  <li>
    <p>Uehara, Zhao, Biancalani, Levine. <em>Understanding Reinforcement Learning-Based Fine-Tuning of Diffusion Models: A Tutorial and Review</em>. arXiv 2024.<br />
<a href="https://arxiv.org/abs/2407.13734">arXiv</a></p>
  </li>
  <li>
    <p>Liu, Liu, Liang, Li, Liu, Wang, Wan, Zhang, Ouyang. <em>Flow-GRPO: Training Flow Matching Models via Online RL</em>. NeurIPS 2025.<br />
<a href="https://openreview.net/forum?id=oCBKGw5HNf">OpenReview</a></p>
  </li>
  <li>
    <p>Ding, Ye. <em>TreeGRPO: Tree-Advantage GRPO for Online RL Post-Training of Diffusion Models</em>. ICLR 2026.<br />
<a href="https://openreview.net/forum?id=3rZdp4TmUb">OpenReview</a></p>
  </li>
  <li>
    <p>Zheng, Chen, Ye, Wang, Zhang, Jiang, Su, Ermon, Zhu, Liu. <em>DiffusionNFT: Online Diffusion Reinforcement with Forward Process</em>. ICLR 2026 Oral.<br />
<a href="https://openreview.net/forum?id=VJZ477R89F">OpenReview</a></p>
  </li>
</ol>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="diffusion" /><category term="reinforcement learning" /><category term="generative model" /><category term="machine learning" /><category term="computer vision" /><summary type="html"><![CDATA[Diffusion 模型最初是按“去噪 MSE / 似然近似”来训练的，但真正上线时，我们更关心的往往不是似然，而是：]]></summary></entry><entry><title type="html">让大模型快 8 倍：从投机解码到 DDTree 的完整原理</title><link href="https://liyongzhi.xyz/posts/2026/04/speculative-decoding/" rel="alternate" type="text/html" title="让大模型快 8 倍：从投机解码到 DDTree 的完整原理" /><published>2026-04-20T00:00:00+08:00</published><updated>2026-04-20T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2026/04/blog-post-speculative-decoding</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2026/04/speculative-decoding/"><![CDATA[<blockquote>
  <p>本文从零开始，带你理解 LLM 推理加速的核心思路，读完之后你会明白：大模型为什么慢、投机解码如何加速、为什么加速后输出质量完全不变，以及 DDTree 这篇 2026 年的新论文究竟做了什么创新。</p>
</blockquote>

<hr />

<h2 id="1-大模型推理为什么慢">1. 大模型推理为什么慢？</h2>

<p>你每次跟 ChatGPT 或 Claude 对话，它生成文字的方式其实非常朴素——<strong>每次只生成一个 token（词片段），然后把这个 token 加进上下文，再生成下一个，如此往复</strong>。</p>

<p>这种方式叫做<strong>自回归解码（Autoregressive Decoding）</strong>。它的问题在于天然的串行性：第 $n$ 个 token 必须等第 $n-1$ 个 token 生成完才能开始，完全无法并行。</p>

<p>更麻烦的是，现代大模型动辄数百亿参数，每生成一个 token 就要做一次完整的前向传播，把所有参数都过一遍。而 GPU 最擅长的恰恰是<strong>大批量并行计算</strong>——生成单个 token 这件事，对 GPU 来说几乎是一种浪费，大量算力处于闲置状态。</p>

<p>用一个类比：这就好比你雇了一个有 100 条生产线的工厂，却每次只让它生产一个零件，然后等零件出来之后再决定下一个生产什么。工厂的产能大部分都在空转。</p>

<p>那有没有办法，让大模型一次性并行生成多个 token，又不影响输出质量？</p>

<p>这就是<strong>投机解码（Speculative Decoding）</strong>的出发点。</p>

<hr />

<h2 id="2-投机解码用小模型猜用大模型验">2. 投机解码：用小模型”猜”，用大模型”验”</h2>

<p>投机解码的核心思想简单到令人惊喜：</p>

<blockquote>
  <p>用一个<strong>轻量的草稿模型（Draft Model）</strong> 先快速生成多个候选 token，再让<strong>大的目标模型（Target Model）</strong> 一次性并行验证这些候选 token，接受正确的，拒绝错误的。</p>
</blockquote>

<p>为什么这能加速？关键在于一个不对称性：<strong>大模型验证多个 token 和验证单个 token，所需要的计算量几乎相同</strong>。这是因为验证本质上是一次 prefill（并行处理整个序列），而不是逐个自回归生成。</p>

<p>具体流程是这样的：</p>

<ol>
  <li><strong>草稿阶段</strong>：草稿模型连续生成 $k$ 个 token（比如 4 个），速度很快</li>
  <li><strong>验证阶段</strong>：目标模型把这 $k$ 个候选 token 一次性并行处理，输出对每个位置的概率判断</li>
  <li><strong>接受/拒绝</strong>：从左到右逐个判断，接受的 token 保留，遇到第一个被拒绝的 token 就停止，后面的全部丢弃</li>
  <li><strong>下一轮</strong>：从被拒绝的位置重新开始草稿</li>
</ol>

<p>如果草稿模型猜对了 3 个 token，那这一轮相当于大模型用一次前向传播的时间，产出了 3 个高质量 token，而不是通常的 1 个。加速比可以非常显著。</p>

<h3 id="草稿模型的-qx-究竟是什么">草稿模型的 $q(x)$ 究竟是什么？</h3>

<p>一个常见的细节困惑：草稿模型给每个 token 的概率 $q(x)$，到底指什么？</p>

<ul>
  <li><strong>Temperature = 0（贪心）</strong>：草稿直接取 argmax 那个 token，$q(x)$ 就是该 token 在 softmax 后对应的那个最大概率值</li>
  <li><strong>Temperature &gt; 0（采样）</strong>：按分布采样一个 token，$q(x)$ 就是采到的那个 token 对应的概率值</li>
  <li><strong>DFlash 这类块扩散草稿</strong>：一次 forward pass 直接输出每个位置的完整边际分布，$q(x)$ 就是该分布在对应 token 上的概率值</li>
</ul>

<p>无论哪种情况，$q(x)$ 都是”草稿模型对它自己选出的那个 token 所赋予的概率”，而不是整个分布。这一点在后面理解接受/拒绝规则时非常重要。</p>

<hr />

<h2 id="3-最关键的问题输出质量真的不变吗">3. 最关键的问题：输出质量真的不变吗？</h2>

<p>这是整个方法最精妙的地方，也是很多人最困惑的地方。</p>

<p><strong>答案是：严格不变，数学上可以证明。</strong></p>

<p>关键在于接受/拒绝的规则设计。对于草稿模型提出的某个 token $x$：</p>

<ul>
  <li>草稿模型给它的概率是 $q(x)$</li>
  <li>目标模型给它的概率是 $p(x)$</li>
</ul>

<p><strong>以概率 $\min!\left(1,\, \frac{p(x)}{q(x)}\right)$ 接受这个 token。</strong></p>

<p>直觉很简单：如果目标模型认为这个 token 的概率至少和草稿模型一样高（$p \geq q$），就无条件接受；如果目标模型认为它概率更低（$p &lt; q$），就按比例拒绝，避免这个 token 被过度采样。</p>

<p><strong>被拒绝时怎么办？</strong> 不是直接放弃，而是从一个”残差分布”里重新采样：</p>

\[p'(x) = \frac{\max(0,\; p(x) - q(x))}{1 - \beta}, \quad \text{其中 } \beta = \sum_y \min(p(y), q(y))\]

<p>这个残差分布的含义是：草稿模型过度提名了某些 token（$q &gt; p$），那些 token 残差为 0，不再有机会；而那些被草稿欠缺代表的 token（$p &gt; q$），按亏欠量分配权重，作为补偿。</p>

<h3 id="一个具体的数字例子">一个具体的数字例子</h3>

<p>抽象公式不好懂，换一个只有 5 个 token 的小词表来算一遍。假设目标分布 $p$ 和草稿分布 $q$ 如下，草稿模型抽出来的 token 是 “cat”：</p>

<table>
  <thead>
    <tr>
      <th>Token</th>
      <th>$p(x)$</th>
      <th>$q(x)$</th>
      <th>$\min(p,q)$</th>
      <th>$\max(0,\, p-q)$</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>cat</td>
      <td>0.50</td>
      <td>0.70</td>
      <td>0.50</td>
      <td>0</td>
    </tr>
    <tr>
      <td>dog</td>
      <td>0.20</td>
      <td>0.10</td>
      <td>0.10</td>
      <td>0.10</td>
    </tr>
    <tr>
      <td>sat</td>
      <td>0.15</td>
      <td>0.05</td>
      <td>0.05</td>
      <td>0.10</td>
    </tr>
    <tr>
      <td>on</td>
      <td>0.10</td>
      <td>0.10</td>
      <td>0.10</td>
      <td>0</td>
    </tr>
    <tr>
      <td>mat</td>
      <td>0.05</td>
      <td>0.05</td>
      <td>0.05</td>
      <td>0</td>
    </tr>
  </tbody>
</table>

<p><strong>第一步：判断是否接受 “cat”</strong>。因为 $p(\text{cat})/q(\text{cat}) = 0.50/0.70 \approx 0.71 &lt; 1$，以 71% 的概率接受，29% 的概率拒绝。</p>

<p><strong>第二步：一旦被拒绝，从残差分布 $p’$ 重新采样</strong>。先算 $\beta = \sum \min(p, q) = 0.50 + 0.10 + 0.05 + 0.10 + 0.05 = 0.80$，因此 $1 - \beta = 0.20$。把 $\max(0,\, p-q)$ 这一列每个元素除以 $0.20$：</p>

<table>
  <thead>
    <tr>
      <th>Token</th>
      <th>残差概率 $p’(x)$</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>cat</td>
      <td>0 / 0.20 = <strong>0</strong></td>
    </tr>
    <tr>
      <td>dog</td>
      <td>0.10 / 0.20 = <strong>0.50</strong></td>
    </tr>
    <tr>
      <td>sat</td>
      <td>0.10 / 0.20 = <strong>0.50</strong></td>
    </tr>
    <tr>
      <td>on</td>
      <td>0 / 0.20 = <strong>0</strong></td>
    </tr>
    <tr>
      <td>mat</td>
      <td>0 / 0.20 = <strong>0</strong></td>
    </tr>
  </tbody>
</table>

<p>于是：被拒绝后，以 50/50 的概率从 “dog” 和 “sat” 中补采一个。</p>

<h3 id="为什么残差分布不是把-cat-去掉再归一化">为什么残差分布不是”把 cat 去掉再归一化”？</h3>

<p>这里最容易踩的坑：很多人第一反应是”既然 cat 被拒了，就把 cat 从 $p$ 里划掉，剩下的 ${0.20, 0.15, 0.10, 0.05}$ 归一化一下（和是 0.50），按这个分布采一个就好”。</p>

<p>但那样算出来 “on” 会得到 $0.10/0.50 = 0.20$ 的概率。而残差分布给 “on” 的概率是 <strong>0</strong>。差别在哪？</p>

<p>关键是 “on” 满足 $p(\text{on}) = q(\text{on}) = 0.10$——草稿模型已经给 “on” 分配了完全正确比例的概率质量，不多不少。再在补采里给它机会，就会让 “on” 在最终输出分布里超出 0.10。</p>

<p><strong>残差分布只补偿那些被草稿欠缺代表的 token（$p &gt; q$）</strong>，而不是盲目地把剩下的 $p$ 归一化。这才是最终边际分布严格等于 $p$ 的关键。</p>

<h3 id="为什么整体输出分布仍然等于-p">为什么整体输出分布仍然等于 $p$？</h3>

<p>把直接接受和补采的贡献加起来：</p>

\[P_{\text{out}}(x) = \underbrace{\min(p(x),\, q(x))}_{\text{直接接受贡献}} + \underbrace{\max(0,\, p(x) - q(x))}_{\text{补采贡献}} = p(x)\]

<p>这个等式对任意 $p, q$ 都严格成立，无需任何假设。草稿模型质量只影响速度（草稿越准，$\beta$ 越大，补采越少），<strong>永远不影响输出正确性</strong>。</p>

<p>这种技术叫做<strong>拒绝采样（Rejection Sampling）</strong>，是概率论里的经典工具，被投机解码巧妙地应用到了 LLM 推理加速上。</p>

<hr />

<h2 id="4-级联丢弃为什么第一个错了后面全废">4. 级联丢弃：为什么第一个错了，后面全废</h2>

<p>这里有个重要细节需要理解。</p>

<p>草稿模型生成 token 序列时，每个 token 都是条件于前面的 token 生成的。比如草稿模型生成的是 “The cat sat on the mat”，其中 “sat” 是在”已知前两个 token 是 The, cat”的条件下预测的；”on” 是在”已知前三个是 The, cat, sat”的条件下预测的。</p>

<p>一旦 “cat” 被目标模型拒绝，目标模型会在那个位置补采一个不同的 token，比如 “dog”。此时，原草稿里的 “sat” 是基于错误前缀 “The, cat” 生成的 $q(\text{sat} \mid \text{The, cat})$，但真实需要的条件概率是 $q(\text{sat} \mid \text{The, dog})$——这是两个完全不同的分布，原来的 $q$ 值对新上下文毫无参考价值。</p>

<p>因此：<strong>第 $i$ 个位置被拒绝，位置 $i+1$ 及之后的所有草稿 token 必须全部丢弃</strong>，不再验证。</p>

<p>这也意味着：</p>

<ul>
  <li>草稿模型的质量极其重要——第一个 token 就被拒绝，整轮草稿全部白费</li>
  <li>优化的核心指标是<strong>期望接受长度</strong>（expected accepted length）：拒绝发生得越晚，这一轮能白捡的 token 越多</li>
  <li>从 $\beta = \sum_x \min(p(x), q(x))$ 也能看出：当 $q \equiv p$ 时 $\beta = 1$，每个 token 都被接受；当 $q$ 与 $p$ 完全不重叠时 $\beta \to 0$，几乎每个 token 都被拒，速度反而比纯自回归还慢（多跑了一次草稿模型）</li>
</ul>

<p>正是因此，专门为投机解码训练的草稿模型（比如 EAGLE、DFlash）要比通用小模型效果好得多：它们被明确训练成”输出分布尽可能像目标模型”，而不仅仅是”在自己的训练集上 loss 最小”。</p>

<hr />

<h2 id="5-dflash块扩散草稿模型">5. DFlash：块扩散草稿模型</h2>

<p>传统的草稿模型和目标模型一样，也是自回归的——要生成 4 个草稿 token，就要跑 4 次前向传播。自回归草稿的问题是：虽然模型小，但”串行”这个结构性瓶颈一点没变。EAGLE、EAGLE-2、EAGLE-3 这类主流草稿模型都是自回归派系的优化（更轻的 head、共享 hidden state 等），它们在单步上做得很快，但仍然要跑 $k$ 次 forward 才能出 $k$ 个草稿。</p>

<p><strong>DFlash</strong> 走了完全不同的一条路：<strong>块扩散（Block Diffusion）</strong>。它基于 Arriola 等人在 ICLR 2025 提出的 <code class="language-plaintext highlighter-rouge">BD3-LM</code>（Block Discrete Denoising Diffusion Language Model）范式，并针对投机解码场景做了专门适配。</p>

<h3 id="51-核心范式块之间自回归块内扩散">5.1 核心范式：块之间自回归，块内扩散</h3>

<p>BD3-LM 的一句话总结：<strong>块与块之间保持自回归，块内部用离散扩散一次性并行预测所有位置</strong>。</p>

<p>具体来说：</p>

<ul>
  <li>把要生成的长序列切成若干固定长度 $L_b$ 的块（例如 $L_b = 4, 8, 16$）</li>
  <li>块与块之间是严格的因果顺序：第 $j$ 块的生成必须在第 $j-1$ 块完成之后进行</li>
  <li>块内部的 $L_b$ 个位置通过<strong>离散扩散</strong>一次性联合预测</li>
</ul>

<p>这种混合结构解决了纯 diffusion 语言模型的两大痛点：<strong>固定长度限制</strong>和<strong>无法做 KV cache</strong>。因为块之间还是自回归，已经生成好的块就可以像普通 LLM 一样把 K/V 缓存下来，后续块只需要在缓存之上再做 attention——这和目标模型的 KV cache 使用完全兼容。</p>

<h3 id="52-训练目标块内的-masked-diffusion-elbo">5.2 训练目标：块内的 masked diffusion ELBO</h3>

<p>训练的时候，对每一个训练样本：</p>

<ol>
  <li>采一个噪声水平 $t \in [0, 1]$</li>
  <li>在<strong>当前块</strong>的 $L_b$ 个位置上，独立地以概率 $t$ 把每个 token 替换成特殊的 <code class="language-plaintext highlighter-rouge">[MASK]</code> token</li>
  <li>保留前面所有完整的块（上下文），只在当前块的位置上加噪</li>
  <li>模型的任务是从被 mask 的块中恢复原始 token</li>
</ol>

<p>损失函数本质上是离散 diffusion 的 ELBO，形式上可以理解成一个加权的交叉熵：</p>

\[\mathcal{L}_{\text{BD}} = \mathbb{E}_{t, x}\left[\sum_{i \in \text{当前块}} \mathbb{1}[x_i^{(t)} = \text{MASK}] \cdot \frac{1}{t} \cdot \log p_\theta(x_i \mid x^{(t)}, x_{&lt;\text{block}})\right]\]

<p>几个要点：</p>

<ul>
  <li>只在被 mask 的位置计算 loss（未被 mask 的位置模型只是看到了答案）</li>
  <li>$1/t$ 是重要性权重，把不同噪声水平的贡献归一化</li>
  <li>模型同时看到了”前面已完成的块”和”当前块的部分观测”，需要学会利用这两者一起做预测</li>
</ul>

<h3 id="53-架构块因果-attentionblock-causal-attention">5.3 架构：块因果 attention（Block-Causal Attention）</h3>

<p>BD3-LM 用的还是标准 Transformer，但 attention mask 变了：</p>

<ul>
  <li><strong>同一块内</strong>：全 attention（每个位置可以看到块内所有位置，包括被 mask 的）</li>
  <li><strong>跨块</strong>：因果 attention（当前块可以看前面所有块，反之不行）</li>
</ul>

<p>画出来像这样：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            块1  块2  块3
       块1 [■ ■ ■][□ □ □][□ □ □]
       块2 [■ ■ ■][■ ■ ■][□ □ □]
       块3 [■ ■ ■][■ ■ ■][■ ■ ■]
</code></pre></div></div>

<p>（■ 可见，□ 不可见；每块 3 个位置）</p>

<p>这个结构让 KV cache 成为可能：块 1 算完之后，它的 K/V 就固定了；块 2 生成时复用块 1 的缓存，只需要新计算块 2 这 $L_b$ 个位置的 K/V；块 3 再叠加。跟自回归 LLM 的 KV cache 行为一致。</p>

<h3 id="54-采样t-步去噪一次得到一个块">5.4 采样：T 步去噪一次得到一个块</h3>

<p>给定前面的块做上下文，生成当前块的流程是：</p>

<ol>
  <li>把当前块的所有 $L_b$ 个位置初始化为 <code class="language-plaintext highlighter-rouge">[MASK]</code></li>
  <li>进行 $T$ 步去噪：每一步，模型看一眼当前的”部分 mask 状态”，对每个仍然是 <code class="language-plaintext highlighter-rouge">[MASK]</code> 的位置输出一个 softmax 分布，按某种策略（如置信度 top-k 或随机）选一部分位置把 mask 替换成采样到的 token</li>
  <li>直到所有位置都被填上，当前块完成</li>
</ol>

<p>关键：<strong>最后一步去噪得到的 $L_b$ 个 softmax 分布，就是 DDTree 要用的”每个位置的概率分布”</strong>。也就是说，DFlash 在块生成的最终一步 forward pass 里，同时暴露了块内每个位置的完整后验分布——这正是后面 DDTree 能用它构树的原因。</p>

<p>BD3-LM 原论文在一般生成时用 $T = 5000$ 步（质量优先），但在 draft 场景下完全不需要这么多——DFlash 把 $T$ 降到很小（论文级别通常 $T \in {1, 2, 4}$ 甚至 single-step），用”一次去噪直接输出块”的方式保证草稿延迟最低。</p>

<h3 id="55-块大小的权衡">5.5 块大小的权衡</h3>

<p>块大小 $L_b$ 是 DFlash 最关键的超参数：</p>

<table>
  <thead>
    <tr>
      <th>块大小</th>
      <th>草稿模型调用次数</th>
      <th>单次输出 token 数</th>
      <th>被接受长度期望</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>小（$L_b=4$）</td>
      <td>多</td>
      <td>4</td>
      <td>短</td>
    </tr>
    <tr>
      <td>中（$L_b=8$）</td>
      <td>中</td>
      <td>8</td>
      <td>中</td>
    </tr>
    <tr>
      <td>大（$L_b=16$）</td>
      <td>少</td>
      <td>16</td>
      <td>可能长也可能第一个就被拒</td>
    </tr>
  </tbody>
</table>

<p>BD3-LM 论文显示，$L_b$ 越大，困惑度恶化越明显（diffusion 建模 16 个位置比建模 4 个位置难得多）；但 $L_b$ 越小，草稿模型要被调用更多次才能”推进”相同距离。投机解码里通常选 $L_b \in {4, 8}$，在单次出 token 数和接受长度之间取平衡。</p>

<h3 id="56-dflash-针对投机解码的特化">5.6 DFlash 针对投机解码的特化</h3>

<p>把 block diffusion 作为草稿模型，DFlash 相比原始 BD3-LM 做了几项适配：</p>

<ul>
  <li><strong>目标模型蒸馏</strong>：训练时不只最大化 block diffusion ELBO，还加一项”草稿输出分布尽可能贴近目标模型”的 KL 损失，让 $\beta = \sum \min(p, q)$ 尽可能大</li>
  <li><strong>极少去噪步数</strong>：从 5000 步降到 1–4 步，在 draft 阶段用”近似一步到位”换延迟</li>
  <li><strong>输出带温度的完整分布</strong>：不像 BD3-LM 需要最终离散 token，DFlash 保留最后一步 softmax 的完整分布以供 DDTree 这类下游使用</li>
  <li><strong>共享 tokenizer 和 context</strong>：和目标模型使用同一份 tokenizer，KV cache 可以部分对齐复用</li>
</ul>

<h3 id="57-一次-forward-pass-实际输出什么">5.7 一次 forward pass 实际输出什么？</h3>

<p>有了上面这些背景，具体看 DFlash 一次前向传播（假设 $L_b = 16$，$T = 1$）的输出：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>输入：已经生成的 prefix tokens + 16 个 [MASK] 占位符
       │
       ▼（一次 block-causal transformer forward pass）
       │
输出：16 个 softmax 分布（每个分布大小 = 词表大小）

位置 1 的分布：cat=0.70, dog=0.20, sat=0.08, ...
位置 2 的分布：sat=0.60, lay=0.30, on=0.08, ...
位置 3 的分布：on=0.65, at=0.20, the=0.12, ...
...
位置 16 的分布：...
</code></pre></div></div>

<p>然后 DFlash 从每个位置的分布里各采样一个 token，拼成草稿序列，交给目标模型验证。</p>

<p>这比自回归草稿快很多——不管块有多长，草稿生成只需要一次前向传播。DFlash 在实测中已经超越了 EAGLE-3 等强大的自回归草稿模型，成为投机解码的领先方案。</p>

<p>但 DFlash 有一个明显的浪费：<strong>每个位置的概率分布包含丰富的候选信息，却只用了最高概率的一个 token</strong>。</p>

<p>比如位置 1，cat 概率 0.70，dog 概率 0.20。DFlash 只采了 cat，dog 的可能性完全被忽略了。如果目标模型拒绝了 cat，这一轮就白费了，而其实 dog 也是一个很有希望的候选。</p>

<p>能不能把这些信息都利用起来？</p>

<hr />

<h2 id="6-ddtree从一条路到一棵树">6. DDTree：从”一条路”到”一棵树”</h2>

<p><strong>DDTree（Diffusion Draft Tree）</strong> 就是这个问题的答案。</p>

<p>核心思路极其直觉：既然 DFlash 一次 forward pass 已经给出了每个位置的完整分布，为什么不用这些分布构建<strong>多条候选路径</strong>，组成一棵树，让目标模型一次性验证整棵树？</p>

<h3 id="61-草稿树长什么样">6.1 草稿树长什么样？</h3>

<p>树的根节点是上一轮结尾的 token，然后从每个位置的分布里选出 top-K 个候选，展开成多条分支：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>           root（…The）
          /            \
      cat (0.70)      dog (0.20)
      /      \              \
  sat(0.42) lay(0.21)    sat(0.12)
    /
 on(0.27)
</code></pre></div></div>

<p>每条从根到叶子的路径，就是一个完整的草稿序列。目标模型可以同时验证所有路径，找出其中最长的被接受前缀。</p>

<h3 id="62-节点预算-b速度与收益的权衡">6.2 节点预算 B：速度与收益的权衡</h3>

<p>树里的节点越多，覆盖的候选路径越丰富，被接受的 token 越多。但同时，目标模型验证时需要处理更多节点，开销也更大。</p>

<p>DDTree 用一个<strong>节点预算 $B$</strong> 来控制树的大小——$B$ 就是整棵树里最多允许几个节点槽。在 $B$ 范围内尽可能选出最有价值的节点。论文实验表明，最优预算大约是 $B = 512$。</p>

<h3 id="63-怎么在预算内选出最优树best-first-heap">6.3 怎么在预算内选出最优树？Best-First Heap</h3>

<p>每条路径的联合概率 = 各位置边际概率之积（因为 DFlash 输出的是独立的按位置分布）：</p>

\[P(\text{cat} \to \text{sat} \to \text{on}) = 0.70 \times 0.60 \times 0.65 \approx 0.27\]

<p>DDTree 用一个<strong>最优先堆（Best-First Heap）</strong> 来贪心构建最优树：</p>

<ol>
  <li>初始把根节点的所有子节点候选放入堆，按概率排序</li>
  <li>每次弹出概率最高的叶子节点，将其子节点（下一位置的 top-K）推入堆</li>
  <li>重复，直到节点总数达到预算 $B$</li>
</ol>

<p>整个过程<strong>完全基于 DFlash 同一次 forward pass 的输出</strong>，不需要再调草稿模型，时间复杂度仅 $O(B \log(B \cdot K))$，极快。可以证明，这个算法在预算 $B$ 内构建出的树，能<strong>最大化期望被接受的 token 数量</strong>。</p>

<h3 id="64-ancestor-only-attention-mask一次并行验证整棵树">6.4 Ancestor-only Attention Mask：一次并行验证整棵树</h3>

<p>树构建好之后，把所有节点按 BFS/DFS 顺序拉平成一维序列，输入目标模型做 prefill。但需要一个特殊的 attention mask：<strong>每个节点只能 attend 到它的祖先节点</strong>，不能看到兄弟或表亲节点。</p>

<p>这保证了每条路径的验证是独立正确的——就好像每条路径单独 prefill，但通过共享公共祖先的 KV（兄弟路径共享同一份祖先的 KV，只存一份），实际上一次 forward pass 就完成了所有路径的验证，效率极高。</p>

<p><strong>kernel 实现的坑</strong>：这种任意稀疏的 attention mask，标准 FlashAttention 并不原生支持，而且树结构每一轮都在变化，无法提前编译。实践上通常用 <strong>Triton 实现的 FlashAttention 变体</strong>，先把 mask 预处理为 block 格式，然后在 kernel 里跳过全零 block 的计算——用稀疏性换效率。这是 DDTree 工程落地的关键之一。</p>

<h3 id="65-贪心走树一个具体例子">6.5 贪心走树：一个具体例子</h3>

<p>验证完成后，<strong>贪心沿最长接受路径走树</strong>：从根出发，每一步看当前节点的哪些孩子被目标模型接受了，如果接受了就往下走一层，遇到第一个被拒绝的节点就停止，把该路径上所有接受的 token 一次性提交。</p>

<p>举个例子，假设树如下：</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>         root
       /      \
    cat ✓    dog ✗
   /    \
sat ✓  lay ✗
  /
on ✗
</code></pre></div></div>

<p>走树过程：从 root 出发 → “cat” 被接受，进入 cat → “sat” 被接受，进入 sat → “on” 被拒绝，停止。本轮提交 <code class="language-plaintext highlighter-rouge">cat → sat</code>（2 个 token），并在 “on” 的位置让目标模型额外采一个 bonus token 作为下一轮草稿的起点。被拒绝节点的子树（比如 “dog” 下面挂着的那一整条分支）完全丢弃。</p>

<h3 id="66-kv-cache-怎么回收">6.6 KV Cache 怎么回收？</h3>

<p>这是树形投机解码最优雅的地方——<strong>根本不需要显式”删除”</strong>。</p>

<p>验证本质是一次 prefill，prefill 结束后只把接受路径上的节点 KV 追加到主序列的 KV Cache 末尾。拒绝节点的 KV 值虽然在 prefill 里算过，但<strong>从未被持久化</strong>到 KV Cache 里，它们只是 attention 计算的中间张量，离开 kernel 就自动废弃，无需任何显式释放。</p>

<h3 id="67-batch-推理的挑战">6.7 Batch 推理的挑战</h3>

<p>单请求的 DDTree 已经比较清楚，但工业落地必然要支持多请求并发。方案是把多个请求的树节点 concat 成一个长序列做一次 prefill，请求之间用 attention mask 完全隔离——A 请求的任何节点看不到 B 请求的任何节点。</p>

<p>DDTree 论文的实现聚焦单请求，<strong>真正生产级的多请求 batch 需要类似 vLLM 的调度层配合树形 attention kernel</strong>，包括不同请求之间 KV Cache 的 paged 管理、不同请求树预算的动态分配等。这是当前投机解码生产落地最主要的工程难点。</p>

<hr />

<h2 id="7-效果全面超越-dflash">7. 效果：全面超越 DFlash</h2>

<p>DDTree 在 60 个数据集-模型-温度组合上，全部优于原始 DFlash。以 Qwen3-8B 在 temperature=0 时为例：</p>

<table>
  <thead>
    <tr>
      <th>数据集</th>
      <th>DFlash</th>
      <th>DDTree</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>MATH-500</td>
      <td>5.56×</td>
      <td><strong>7.52×</strong></td>
    </tr>
    <tr>
      <td>HumanEval</td>
      <td>4.84×</td>
      <td><strong>6.90×</strong></td>
    </tr>
    <tr>
      <td>GSM8K</td>
      <td>4.78×</td>
      <td><strong>6.75×</strong></td>
    </tr>
    <tr>
      <td>AIME’24</td>
      <td>—</td>
      <td><strong>7.3×</strong></td>
    </tr>
    <tr>
      <td>AIME’25</td>
      <td>—</td>
      <td><strong>7.2×</strong></td>
    </tr>
  </tbody>
</table>

<p>更大的 30B 代码模型上，HumanEval 加速比达到 <strong>8.22×</strong>，接近自回归解码速度的 8 倍。</p>

<p>而且 DDTree 是<strong>完全无损的</strong>——所有 token 都经过目标模型验证，输出分布与原始自回归解码在数学上完全等价。论文代码基于 HuggingFace Transformers 开源实现，结果可复现。</p>

<hr />

<h2 id="8-总结一张图理解全套系统">8. 总结：一张图理解全套系统</h2>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[自回归解码的困境]
  每次只生成 1 个 token → GPU 闲置 → 延迟高

[投机解码的思路]
  草稿模型快速提出候选 → 目标模型并行验证 → 数学上等价于原始分布

[接受/拒绝规则]
  以 min(1, p/q) 的概率接受草稿 token
  被拒绝时从残差分布 p'(x) 补采 → 无论如何输出都等于 p(x)

[DFlash 的创新]
  块扩散：一次 forward pass 输出整块的按位置分布
  比自回归草稿快，但每个位置只用了一个 token

[DDTree 的创新]
  利用 DFlash 的完整分布，在节点预算 B 内用 Best-First Heap 构建最优草稿树
  目标模型用 ancestor-only attention mask 一次验证整棵树
  贪心走树取最长接受路径 → 每轮接受更多 token → 速度提升 35%+
</code></pre></div></div>

<p>投机解码的整个故事，本质上是一次关于”如何把 GPU 的并行能力用足”的精妙设计。从朴素的串行解码，到草稿-验证的并行框架，再到树形结构对信息的充分利用，每一步都在回答同一个问题：<strong>怎样让大模型在不降低质量的前提下，尽可能快地生成文字</strong>。</p>

<hr />

<p><em>参考论文：Ringel &amp; Romano, “Accelerating Speculative Decoding with Block Diffusion Draft Trees”, arXiv 2604.12989, 2026.</em></p>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="llm" /><category term="speculative decoding" /><category term="inference acceleration" /><category term="diffusion" /><category term="machine learning" /><summary type="html"><![CDATA[本文从零开始，带你理解 LLM 推理加速的核心思路，读完之后你会明白：大模型为什么慢、投机解码如何加速、为什么加速后输出质量完全不变，以及 DDTree 这篇 2026 年的新论文究竟做了什么创新。]]></summary></entry><entry><title type="html">Python 利用selenium 控制浏览器自动提交表单</title><link href="https://liyongzhi.xyz/posts/2023/02/python-selenium-auto-form-submit/" rel="alternate" type="text/html" title="Python 利用selenium 控制浏览器自动提交表单" /><published>2023-02-07T00:00:00+08:00</published><updated>2023-02-07T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2023/02/blog</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2023/02/python-selenium-auto-form-submit/"><![CDATA[<hr />

<hr />

<script async="" src="//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js"></script>

<div>
<div class="button01">
      <visited_a href="#" display:inline=""><span id="busuanzi_container_site_pv">你是第<span id="busuanzi_value_site_pv"></span>位访客~</span></visited_a>
      <visited_p class="top">٩(๑^o^๑)۶</visited_p>
      <visited_p class="bottom">Σ(っ °Д °;)っ被你发现了！</visited_p>
</div>
<img align="center" width="100" src="https://liyongzhi.xyz/images/static/take_me.gif" alt="" display:inline="" />
</div>

<hr />

<h2 id="python-利用selenium-自动控制浏览器提交表单">Python 利用selenium 自动控制浏览器提交表单</h2>

<h3 id="前期准备">前期准备</h3>

<ol>
  <li>下载安装chrome webdriver <code class="language-plaintext highlighter-rouge">https://sites.google.com/chromium.org/driver/downloads?authuser=0</code></li>
  <li>安装selenium <code class="language-plaintext highlighter-rouge">pip install seleuim</code></li>
</ol>

<h3 id="执行代码">执行代码</h3>

<ul>
  <li>如果有一些网站需要登录，可以执行以下命令启动一个常驻浏览器，并且将用户信息写到指定路径</li>
</ul>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># 针对macos 上的Chrome 浏览器</span>
<span class="nb">export </span><span class="nv">PATH</span><span class="o">=</span><span class="s2">"/Applications/Google Chrome.app/Contents/MacOS:</span><span class="nv">$PATH</span><span class="s2">"</span>
Google<span class="se">\ </span>Chrome <span class="nt">--remote-debugging-port</span><span class="o">=</span>9222 <span class="nt">--user-data-dir</span><span class="o">=</span><span class="s2">"~/ChromeProfile"</span>
</code></pre></div></div>

<ul>
  <li>经过以上操作就会启动一个浏览器，后续使用代码可以控制该浏览器上的行为。</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">selenium</span> <span class="kn">import</span> <span class="n">webdriver</span> <span class="c1"># selenium.__version__ = 4.8.0
</span><span class="kn">from</span> <span class="nn">selenium.webdriver.common.by</span> <span class="kn">import</span> <span class="n">By</span>
<span class="kn">from</span> <span class="nn">selenium.webdriver.common.action_chains</span> <span class="kn">import</span> <span class="n">ActionChains</span>
<span class="kn">import</span> <span class="nn">time</span>


<span class="s">'''
如果想要保持登录状态：
打开一个terminal执行以下命令：
export PATH="/Applications/Google Chrome.app/Contents/MacOS:$PATH"
Google\ Chrome --remote-debugging-port=9222 --user-data-dir="~/ChromeProfile"

启动chrome驻留之后再执行以下代码
'''</span>

<span class="c1"># 打开浏览器驱动
</span><span class="n">option</span> <span class="o">=</span> <span class="n">webdriver</span><span class="p">.</span><span class="n">ChromeOptions</span><span class="p">()</span>
<span class="n">option</span><span class="p">.</span><span class="n">add_experimental_option</span><span class="p">(</span><span class="s">"debuggerAddress"</span><span class="p">,</span> <span class="s">"127.0.0.1:9222"</span><span class="p">)</span>

<span class="c1"># 启动浏览器
</span><span class="n">driver</span> <span class="o">=</span> <span class="n">webdriver</span><span class="p">.</span><span class="n">Chrome</span><span class="p">(</span><span class="n">options</span> <span class="o">=</span> <span class="n">option</span><span class="p">)</span>
<span class="n">driver</span><span class="p">.</span><span class="n">implicitly_wait</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span>

<span class="k">class</span> <span class="nc">ServiceConfig</span><span class="p">():</span>

    <span class="c1"># 定义prepareWork函数，做准备工作
</span>    <span class="k">def</span> <span class="nf">prepareWork</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span><span class="n">url</span><span class="p">):</span>
        <span class="c1"># 打开百度首页
</span>        <span class="n">driver</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>

        <span class="c1"># 查找搜索框元素
</span>        <span class="n">search_input</span> <span class="o">=</span> <span class="n">driver</span><span class="p">.</span><span class="n">find_element</span><span class="p">(</span><span class="n">By</span><span class="p">.</span><span class="n">XPATH</span><span class="p">,</span><span class="s">'//*[@id="root"]/div[1]/div[2]/div/div[2]/div[1]/div/div/div/input'</span><span class="p">)</span>

        <span class="c1"># 在搜索框中输入文本
</span>        <span class="n">search_input</span><span class="p">.</span><span class="n">send_keys</span><span class="p">(</span><span class="s">"人体工学椅子"</span><span class="p">)</span>
        <span class="c1"># time.sleep(2)
</span>
        <span class="c1"># 点击搜索按钮
</span>        <span class="n">search_button</span> <span class="o">=</span> <span class="n">driver</span><span class="p">.</span><span class="n">find_element</span><span class="p">(</span><span class="n">By</span><span class="p">.</span><span class="n">XPATH</span><span class="p">,</span><span class="s">'//*[@id="root"]/div[1]/div[2]/div/div[2]/div[1]/div/button/span'</span><span class="p">)</span>
        <span class="n">ActionChains</span><span class="p">(</span><span class="n">driver</span><span class="p">).</span><span class="n">move_to_element</span><span class="p">(</span><span class="n">search_button</span><span class="p">).</span><span class="n">click</span><span class="p">().</span><span class="n">perform</span><span class="p">()</span>
        <span class="c1"># time.sleep(2)
</span>
        <span class="n">setting_button</span> <span class="o">=</span> <span class="n">driver</span><span class="p">.</span><span class="n">find_element</span><span class="p">(</span><span class="n">By</span><span class="p">.</span><span class="n">XPATH</span><span class="p">,</span><span class="s">'//*[@id="root"]/div[2]/div/div[3]/div/div/div/div[1]/div[2]/p[1]'</span><span class="p">)</span>
        <span class="n">ActionChains</span><span class="p">(</span><span class="n">driver</span><span class="p">).</span><span class="n">move_to_element</span><span class="p">(</span><span class="n">setting_button</span><span class="p">).</span><span class="n">click</span><span class="p">().</span><span class="n">perform</span><span class="p">()</span>
        <span class="c1"># time.sleep(2)
</span>        <span class="n">windows</span> <span class="o">=</span> <span class="n">driver</span><span class="p">.</span><span class="n">window_handles</span>
        <span class="n">driver</span><span class="p">.</span><span class="n">switch_to</span><span class="p">.</span><span class="n">window</span><span class="p">(</span><span class="n">windows</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>

        <span class="n">sousuo_setting</span> <span class="o">=</span> <span class="n">driver</span><span class="p">.</span><span class="n">find_element</span><span class="p">(</span><span class="n">By</span><span class="p">.</span><span class="n">XPATH</span><span class="p">,</span><span class="s">'//*[@id="root"]/div[2]/div/div[3]/div[2]/div[5]/span[1]/button[1]/span'</span><span class="p">)</span>
        <span class="n">ActionChains</span><span class="p">(</span><span class="n">driver</span><span class="p">).</span><span class="n">move_to_element</span><span class="p">(</span><span class="n">sousuo_setting</span><span class="p">).</span><span class="n">click</span><span class="p">().</span><span class="n">perform</span><span class="p">()</span>



<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">'__main__'</span><span class="p">:</span>
    <span class="n">url</span> <span class="o">=</span> <span class="s">'https://www.byte-mall.cn/'</span>
    <span class="n">sc</span> <span class="o">=</span> <span class="n">ServiceConfig</span><span class="p">()</span>
    <span class="n">sc</span><span class="p">.</span><span class="n">prepareWork</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
    <span class="n">time</span><span class="p">.</span><span class="n">sleep</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span>
    

</code></pre></div></div>

<ul>
  <li>具体其他复杂操作可以通过组合鼠标和键盘操作事件来实现</li>
</ul>

<div data-hk-top-pages="5"> 

</div>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="python, tips" /><summary type="html"><![CDATA[]]></summary></entry><entry><title type="html">Python 启动http服务中文乱码问题</title><link href="https://liyongzhi.xyz/posts/2021/11/python-http-server-utf8/" rel="alternate" type="text/html" title="Python 启动http服务中文乱码问题" /><published>2021-11-25T00:00:00+08:00</published><updated>2021-11-25T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2021/11/blog</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2021/11/python-http-server-utf8/"><![CDATA[<hr />

<hr />

<script async="" src="//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js"></script>

<div>
<div class="button01">
            <visited_a href="#" display:inline=""><span id="busuanzi_container_site_pv">你是第<span id="busuanzi_value_site_pv"></span>位访客~</span></visited_a>

      <visited_p class="top">٩(๑^o^๑)۶</visited_p>
      <visited_p class="bottom">Σ(っ °Д °;)っ被你发现了！</visited_p>
</div>
<img align="center" width="100" src="https://liyongzhi.xyz/images/static/take_me.gif" alt="" display:inline="" />
</div>

<hr />

<h2 id="python-启动httpserver服务中文乱码问题">Python 启动http.server服务中文乱码问题</h2>

<h3 id="正常启动方式">正常启动方式：</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">python</span> <span class="o">-</span><span class="n">m</span> <span class="n">http</span><span class="p">.</span><span class="n">server</span> <span class="mi">9999</span> <span class="c1">#在当前路径为根目录的情况下，启动http服务，端口为9999
</span></code></pre></div></div>
<p>但是当html文件为UTF-8编码时，可能会出现中文乱码问题。</p>

<h3 id="解决中文乱码问题">解决中文乱码问题：</h3>

<p>python2:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">python</span> <span class="o">-</span><span class="n">c</span> <span class="s">"import SimpleHTTPServer; m = SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map; m[''] = 'text/plain'; m.update(dict([(k, v + ';charset=UTF-8') for k, v in m.items()])); SimpleHTTPServer.test();"</span> <span class="mi">9999</span>  
</code></pre></div></div>

<p>python3:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">python3</span> <span class="o">-</span><span class="n">c</span> <span class="s">"from http.server import test, SimpleHTTPRequestHandler as RH; RH.extensions_map={k:v+';charset=UTF-8' for k,v in RH.extensions_map.items()}; test(RH,port=9999)"</span>  
</code></pre></div></div>

<h2 id="linux-批量杀掉进程脚本">Linux 批量杀掉进程脚本</h2>

<p>有时候一个命令会启动很多个进程，但是停掉之后还会有很多僵尸进程，以下脚本可以直接批量删除名称包含xxxxx的进程</p>

<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
ps <span class="nt">-aux</span> | <span class="nb">grep</span> <span class="s2">"xxxxxxx"</span> | <span class="nb">grep</span> <span class="nt">-v</span> <span class="nb">grep</span> | <span class="nb">awk</span> <span class="s1">'{print "kill -9  "$2}'</span> | sh

</code></pre></div></div>

<div data-hk-top-pages="5"> </div>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="blog" /><category term="python, tips" /><summary type="html"><![CDATA[]]></summary></entry><entry><title type="html">二叉树的三种遍历（递归与非递归）</title><link href="https://liyongzhi.xyz/posts/2020/09/algorithm-binary-tree-Traversal/" rel="alternate" type="text/html" title="二叉树的三种遍历（递归与非递归）" /><published>2020-09-29T00:00:00+08:00</published><updated>2020-09-29T00:00:00+08:00</updated><id>https://liyongzhi.xyz/posts/2020/09/blog-post-algorithm-tree</id><content type="html" xml:base="https://liyongzhi.xyz/posts/2020/09/algorithm-binary-tree-Traversal/"><![CDATA[<hr />

<hr />

<script async="" src="//busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js"></script>

<div>
<div class="button01">
            <visited_a href="#" display:inline=""><span id="busuanzi_container_site_pv">你是第<span id="busuanzi_value_site_pv"></span>位访客~</span></visited_a>

      <visited_p class="top">٩(๑^o^๑)۶</visited_p>
      <visited_p class="bottom">Σ(っ °Д °;)っ被你发现了！</visited_p>
</div>
<img align="center" width="100" src="https://liyongzhi.xyz/images/static/take_me.gif" alt="" display:inline="" />
</div>
<hr />

<h2 id="二叉树的三种遍历-递归和非递归实现">二叉树的三种遍历 递归和非递归实现</h2>

<ul>
  <li>三种遍历的递归实现</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">'''
递归遍历,实现起来比较简单，主要就是调整root.val放入结果的时机
'''</span>
<span class="n">result</span><span class="o">=</span><span class="p">[]</span>
<span class="k">def</span> <span class="nf">preorder</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">root</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="k">return</span>

    <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
    <span class="n">preorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">left</span><span class="p">)</span>
    <span class="n">preorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">right</span><span class="p">)</span>
    <span class="k">return</span>

<span class="k">def</span> <span class="nf">inorder</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">root</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="k">return</span>
    <span class="n">inorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">left</span><span class="p">)</span>
    <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
    <span class="n">inorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">right</span><span class="p">)</span>
    <span class="k">return</span>


<span class="k">def</span> <span class="nf">postorder</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">root</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">:</span>
        <span class="k">return</span>
    <span class="n">postorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">left</span><span class="p">)</span>
    <span class="n">postorder</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">right</span><span class="p">)</span>
    <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">root</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
    <span class="k">return</span>
</code></pre></div></div>

<ul>
  <li>非递归实现</li>
</ul>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">'''
主要思想是使用栈来模拟递归调用的过程

其实先序和后序的代码非常相似，只是后续每次优先往右边深入，最后逆序一下结果

中序的结果是在while循环外面才保存到result中的，这个需要注意。
'''</span>


<span class="k">def</span> <span class="nf">pre_order</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="n">result</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">stack</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">cur</span><span class="o">=</span><span class="n">root</span>
    <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">stack</span><span class="p">)</span><span class="o">&gt;</span><span class="mi">0</span><span class="p">:</span>
        <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">stack</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">)</span>
            <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
            <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">left</span>

        <span class="n">cur</span><span class="o">=</span><span class="n">stack</span><span class="p">.</span><span class="n">pop</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">right</span>
    <span class="k">return</span> <span class="n">result</span>

<span class="k">def</span> <span class="nf">in_order</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="n">result</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">stack</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">cur</span><span class="o">=</span><span class="n">root</span>
    <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">stack</span><span class="p">)</span><span class="o">&gt;</span><span class="mi">0</span><span class="p">:</span>
        <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">stack</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">)</span>
            <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">left</span>
        <span class="n">cur</span><span class="o">=</span><span class="n">stack</span><span class="p">.</span><span class="n">pop</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
        <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">right</span>
    <span class="k">return</span> <span class="n">result</span>

<span class="k">def</span> <span class="nf">post_order</span><span class="p">(</span><span class="n">root</span><span class="p">):</span>
    <span class="n">result</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">stack</span><span class="o">=</span><span class="p">[]</span>
    <span class="n">cur</span><span class="o">=</span><span class="n">root</span>
    <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">stack</span><span class="p">)</span><span class="o">&gt;</span><span class="mi">0</span><span class="p">:</span>
        <span class="k">while</span> <span class="n">cur</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">.</span><span class="n">val</span><span class="p">)</span>
            <span class="n">stack</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">cur</span><span class="p">.</span><span class="n">right</span><span class="p">)</span>
            <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">right</span>

        <span class="n">cur</span><span class="o">=</span><span class="n">stack</span><span class="p">.</span><span class="n">pop</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">cur</span><span class="o">=</span><span class="n">cur</span><span class="p">.</span><span class="n">left</span>
    <span class="n">result</span><span class="p">.</span><span class="n">reverse</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">result</span>

</code></pre></div></div>

<div data-hk-top-pages="5"> </div>]]></content><author><name>李勇志 (Yongzhi Li)</name><email>yongzhili@pku.edu.cn</email></author><category term="Algorithm" /><category term="Algorithm, Binary Tree" /><summary type="html"><![CDATA[]]></summary></entry></feed>