跳转至

FP8

INT8也能训练

前一篇博客 中,我们深入探讨了DeepSeek V3如何通过FP8实现高效训练,并成功克服了精度挑战。本文探讨另一个问题:如果用INT8代替FP8做训练,会发生什么?

INT8 量化

给定一个浮点数向量 \(x \in \mathbb{R}^n\),INT8量化的目标是将其映射到 [-128, 127] 的整数空间。这一过程需要确定缩放因子 \(\alpha\) 和零点偏移 \(\beta\),使得:

\[ x_q = round(\frac{x}{\alpha}) + \beta \]

其中 \(x_q\) 表示量化后的INT8值。缩放因子 \(\alpha\) 通常通过以下方式计算:

\[ \alpha = \frac{max(|x|)}{127} \]

这确保了量化后的值不会超出INT8的表示范围。而零点偏移 \(\beta\) 在对称量化场景下通常设置为0,在非对称量化时则需要根据数据分布来确定。

对于LLM训练场景,由于权重和激活值通常呈现对称分布,我们可以使用对称量化方案:

def symmetric_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, float]:
    alpha = x.abs().max() / 127.0  # 计算缩放因子
    x_q = torch.round(x / alpha)   # 量化
    x_q = torch.clamp(x_q, -128, 127)  # 截断
    return x_q, alpha

反量化操作则是将INT8值映射回浮点数空间:

\[ x_r = (x_q - \beta) \times \alpha \]

其中 \(x_r\) 是反量化后的浮点数值。在对称量化场景下,由于 \(\beta = 0\),反量化简化为:

def symmetric_dequantize(x_q: torch.Tensor, alpha: float) -> torch.Tensor:
    return x_q * alpha

与FP8的浮点量化不同,INT8采用均匀量化方案:

  • 优势区间:大值区域精度更高(固定量化步长)
  • 劣势区间:小值区域精度较低(相对误差更大)

这种特性使得INT8对数据分布形态更为敏感,需要针对性优化策略。

从DeepSeek V3看FP8训练的挑战

DeepSeek V3 的发布引起了对 FP8 训练的广泛关注,业界也出现了大量文章解析 How 的问题——DeepSeek 是怎么进行 FP8 训练的,与传统方案有哪些不同。但是目前鲜有文章对 Why 问题进行深入探讨,为何 DeepSeek 的方案能够取得成功。本文尝试对 FP8 训练所面临的挑战进行深入解析,并尝试猜测 DeepSeek 团队设计其 FP 方案的背后原理。(如果你对 INT8 训练感兴趣,可以参考本文的姊妹篇:INT8 训练

1. FP8 浮点格式

1.1 FP8 格式的历史

FP8 是一种遵循 IEEE 754 规范6的 8 位浮点数格式,由 Nvidia 在 2022 年发布的 H100 GPU 中首次引入。在此之前,Nvidia 硬件上浮点数格式的发展历程如下3

  • 2016 年 P100 GPU 首次引入 FP16 数据格式,直接开启了深度学习混合精度训练的技术路线;
  • 2017 年 V100 GPU 首次引入 Tensor Core, 用于加速 FP16 矩阵乘法运算;
  • 2020 年 A100 GPU 首次引入 TF32 数据格式,可通过 Tensor Core 加速;引入 bfloat16 数据格式,提供比 FP16 更宽的动态范围(当下 BF16 已经成为 LLM 训练的主流方案);
  • 2022 年 H100 GPU 首次引入 FP8 数据格式;

FP8 被 Nvidia 给予厚望,认为其成功的延续了 CEO 提出的 Huang’s Law4,即 10 年间 GPU 硬件算力提升 1000 倍。在过去的 10 年间,新型数值表达的引入了 16 倍算力提升,是诸多技术中贡献最大者,GPU 架构与复杂指令集紧随其后带来了 12.5 倍提升,而制程进步带来的收益非常有限,仅 2.5 倍5

1.2. 常见浮点数与 IEEE 754

IEEE 754 是目前广为使用的浮点数规范,定义了浮点数的 bitwise 表达与量化方式。浮点数的二进制表达分为三部分:

  • 符号位(sign)
  • 指数位(exponent)
  • 尾数位(mantissa)

常见的浮点数格式的二进制表达如下图所示:

block-beta
    columns 33
    FP32["fp32"]
    S1["S"]
    E1["E"]
    E2["E"]
    E3["E"]
    E4["E"]
    E5["E"]
    E6["E"]
    E7["E"]
    E8["E"]
    M1["M"]
    M2["M"]
    M3["M"]
    M4["M"]
    M5["M"]
    M6["M"]
    M7["M"]
    M8["M"]
    M9["M"]
    M10["M"]
    M11["M"]
    M12["M"]
    M13["M"]
    M14["M"]
    M15["M"]
    M16["M"]
    M17["M"]
    M18["M"]
    M19["M"]
    M20["M"]
    M21["M"]
    M22["M"]
    M23["M"]

    BF16["bf16"]
    SS1["S"]
    EE1["E"]
    EE2["E"]
    EE3["E"]
    EE4["E"]
    EE5["E"]
    EE6["E"]
    EE7["E"]
    EE8["E"]
    MM1["M"]
    MM2["M"]
    MM3["M"]
    MM4["M"]
    MM5["M"]
    MM6["M"]
    MM7["M"]
    space:16

    FP16["fp16"]
    space:3
    ss1["S"]
    ee1["E"]
    ee2["E"]
    ee3["E"]
    ee4["E"]
    ee5["E"]
    mm1["M"]
    mm2["M"]
    mm3["M"]
    mm4["M"]
    mm5["M"]
    mm6["M"]
    mm7["M"]
    mm8["M"]
    mm9["M"]
    mm10["M"]
    space:13

    E5M2["fp8"]
    space:3
    s1["S"]
    e1["E"]
    e2["E"]
    e3["E"]
    e4["E"]
    e5["E"]
    m1["M"]
    m2["M"]
    space:21

    E4M3["fp8"]
    space:4
    sss1["S"]
    eee1["E"]
    eee2["E"]
    eee3["E"]
    eee4["E"]
    mmm1["M"]
    mmm2["M"]
    mmm3["M"]
    space:21

    classDef name fill:#00000000, stroke:#00000000
    class FP32,BF16,FP16,E4M3,E5M2 name

    classDef sign fill:#EE0000, stroke:#00000000
    class S1,SS1,s1,ss1,sss1 sign

    classDef exp fill:#00EE00, stroke:#00000000
    class E1,E2,E3,E4,E5,E6,E7,E8 exp
    class EE1,EE2,EE3,EE4,EE5,EE6,EE7,EE8 exp
    class e1,e2,e3,e4,e5,e6,e7,e8 exp
    class ee1,ee2,ee3,ee4,ee5,ee6,ee7,ee8 exp
    class eee1,eee2,eee3,eee4,eee5,eee6,eee7,eee8 exp

1.3. FP8 有两种格式

随着浮点数位数从 16 位进一步降低到 8 位,动态范围不足的问题逐渐显现。因此 Nvidia、Arm 和 Intel 在 FP8 规范中设计了两种浮点数类型1:E4M3 和 E5M2

E4M3 E5M2
format(s/e/m) 1:4:3 1:5:2
Exponent bias 7 15
Infinities N/A S.11111.00
NaN S.1111.111 S.11111.{01,10,11}
Zeros S.0000.000 S.00000.00
Max normal S.1111.110 = \(1.75 \times 2^8\) = 448 S.11110.11 = \(1.75 \times 2^15\) = 57.344
Min normal S.0001.0000 = \(2^{-6}\) S.00001.00 = \(2^{-14}\)
Max subnorm S.0000.111 = \(0.875 \times 2^{-6}\) S.00000.11 = \(0.75\times 2^{-14}\)
Min subnorm S.0000.001 = \(2^{-9}\) S.00000.01 = $ 2^{-16}$

浮点数都会分配一些二进制表达来表示特殊值**NaN**和 \(\mathbb{\pm}\)Inf,IEEE 754 规范约定使用指数位全**1**的二进制表达来表示这些特殊值。对于 E4M3 格式来说,若严格遵循 IEEE 754 规范,会 8 个二进制表达。因此在定义 E4M3 规范时对这些二进制表达进行了额外开发,仅在指数位尾数位同时全为 1 时才表示 NaN,全为 0 的时候表示 \(\pm\)Inf

H100 的 Tensor Core 提供 3 倍 A100 FP16 性能,若启用 FP8 算力能够再次翻倍。