LLM

环境与沙箱系统:从轨迹到智能体训练基础设施

本文为原创文章,版权归作者所有。未经许可,禁止转载。

大模型训练最早给人的印象是“喂文本”。预训练阶段,模型从大规模语料中学习语言、知识和世界规律;SFT 阶段,模型从人工标注的指令数据中学习如何回答问题。这个视角在早期足够有效,因为模型要学的主要是输入到输出的映射:给定一段 prompt,生成一段符合人类偏好的 response

但当模型开始具备推理、工具调用、代码修改、网页浏览和长程任务执行能力后,训练对象发生了变化。模型最终说了什么不再是最重要的问题。更值得关心的是它 ** 如何一步一步到达结果 **——查了哪些信息,调用了什么工具,执行了哪些命令,观察到什么反馈,在哪一步修正了计划,又是如何判断任务已经完成的。

这些过程合在一起,就是轨迹。

大模型时代的存储:从生成到引用

本文为原创文章,版权归作者所有。未经许可,禁止转载。

大模型推理看起来是 next-token prediction,但真实系统里的推理很少只是“生成”。提示模板、会话历史、长期记忆、检索结果、工具返回和数据库记录,都会进入推理链路。模型并不是在真空中猜下一个词,而是在“继续生成”和“引用已有信息”之间不断切换。

这也是大模型时代重新需要讨论存储的原因。存储不再只是“把东西放起来”,而是一套围绕 可引用性 建立的机制:让信息在合适的时刻、以合适的形态、用可控的成本被再次使用。问题也随之从“数据放在哪里”,变成了“引用路径如何被组织”。

Agent Search 把这个变化放大了。一个 Agent 很少只问一次;它会围绕同一个目标反复追问、调用工具、验证证据、修正计划。真正被反复触碰的往往不是整个知识库,而是当前任务附近的一小片局部工作集。全量数据仍然可以很大,也可以沉在对象存储里;靠近 Agent 的,则是少量语义相关的数据块、证据正文、工具结果和验证材料。存储系统的任务,就从“保存全量知识”变成了“让这些小工作集能被低成本地再次读出来”。

Agent Search 将全局知识库压缩成局部工作集

1:全局数据仍然可以很大,Agent 真正反复访问的是当前任务附近的一小片局部工作集。

Probing 分布式探针开发随笔(三:分布式训练的 Profiling

在前两篇系列文章中,介绍了 Probing 分布式探针的核心理念与技术探索,包括其应对 ABI 兼容性挑战的动态注入机制,以及基于 DataFusion 构建的可扩展查询引擎。虽然仍处于技术原型阶段,但也确实看到了实现一个“完美”工具的可能性。最近在解决某千卡训练项目时,浪费了大量时间在实验、抓数据与复现等工作上,越来越感觉到传统工具的限制,也越来越急迫地需要将 Probing 推向生产。

传统 Profiler 的困境

Profiling 是性能优化工程师最为主要的优化手段,为了分析性能我们有形形色色的 Profiler 工具。大到 Intel VTune Nvidia Insight 这种系列工具,有着完备的分析工具与可视化手段,很多问题都能一目了然;小到perf top这样”简陋”的调用栈采样工具,得边看边猜整个系统的行为。但是这些工具有一个共同的问题:他们都是单机工具,并不能很好的解决分布式系统中的性能问题。比如,PyTorch 提供了torch.profiler,一个强大的内置性能分析工具,使用也极为方便:

from torch.profiler import profile, CPU, CUDA

with profile(activities=[CPU, CUDA]) as prof:
    for step in range(steps):
        train_one_step()
        prof.step()

torch.profiler能够抓取 PyTorch 中的算子执行与显存分配行为,并且可以通过 Tensorboard 对结果进行可视化。但对于千卡规模的分布式训练,torch.profiler还远远不够用:

  1. 性能开销问题:profiler 会显著影响性能,导致 Profiling 结果不准确;
  2. 数据爆炸:单卡长生上 G 的数据,千卡需要上 T 存储;
  3. 缺乏协调:各节点数据相互独立,难以进行关联分析,特别是引入模型并行之后;

思路转变:从 Timeline 到统计方法

Timeline 困境

单节点的性能分析中,Timeline 技术备受追捧。原因无他:直观,每个阶段的执行,开销的资源与消耗的时间都可以在一个时间轴上精确展示出来。但是 Timeline 数据庞大,并且借助浏览器渲染时速度也欠佳。几个 G Timeline 数据很快会让你的笔记本成为一个小火炉。Timeline 也存在明显的局限性:

  • 一般只能分析单个节点单个 Step
  • 多节点多 Step 的数据,看不过来(虽然有些人在此会很倔强
  • 每个 Step 都有差异,导致难以给出结论(可以给定性结论,但结论复现存在难度
  • 难以捕捉这个系统的随机性与不确定性:
  • 单个节点单个 Step 是确定的,但是一千个节点的同一个 Step,充满了随机性;
  • Timeline 无法刻画出整个系统性能层面的统计特性,比如耗时的 99 线;
  • 忽略了负载不均衡现象:
  • 在经典的 Dense LLM 中,因为模型并行会导致每个节点实际负载各不相同;
  • 在流行的 MoE LLM 中,专家路由也会导致计算负载不均衡问题;

上述这些问题难以在 Timeline 框架下靠修修补补来解决:

  1. 在由上千节点与数千线缆构成的复杂计算集群中,节点与节点间互联必然存在随机性与不确定性,这正是分布式系统的核心挑战。因此,分布式系统的性能分析需要从单节点上精准 timeline 的个体样本方法 ** 转向能够描述随机性与整体特性的统计方法 **

  2. 单节点的 Profiling 数据量非常巨大,却又缺乏有效的数据压缩与处理手段。单机尚能撑住,扩展到千卡集群就直接原地爆炸。分布式 Profiling 必然需要 ** 转向现代化的数据基建,全面拥抱分布式的数据存储与分析技术 **

  3. 分布式系统中很难保证时间一致与时间精度,多机 timeline 很难进行对齐,也很难可视化分析(可以想象下一千张卡的 timelime 等你去看如何在不依赖精准时间戳的情况下进行数据关联分析、识别性能异常节点 (Stragglers),也成为分布式系统性能分析的关键挑战

分布式系统的统计思想

分布式系统最常见的性能分析范式是分布式 Tracing,如 OpenTelemetry 这类系统已在微服务领域取得了成功,这些系统的核心理念可以适配到分布式训练环境:

  1. 借鉴 Span 概念:将训练过程分解自顶向下的、嵌套的 span。前向传播、反向传播作为顶级 Span,每个 layer 的计算作为子 Span。这种层次化的视图只需要明确层次关系,而无须精确的时间戳对齐。

  2. 优化采样策略:不同于 timeline 的全量采样,分布式 profiling 可以通过设计采样策略来控制开销:
    - 结构化采样:根据模型结构进行采样而非完全随机采样;
    - 分布式采样:将采样操作分布到不同的节点,降低每个节点的采样量;

  3. 分析效率:模型训练中每个 span 内的计算量与通信量可以精确计算,结合 span 计时即可分析每一段时间的硬件利用率与瓶颈,而无须像 timeline 那样精需要精准的时间信息。

  4. 统计视角替代精确时间线:关注分布特性(均值、中位数、百分位数)而非单个精确时间点,使问题分析更符合分布式系统的随机性特质。

不过分布式训练的通信模式是集合通信而非调用树,可以尝试为训练系统单独设计一套分布式 Profiling 方案。

基于探针的分布式 Profiling

训练系统分布式 Profiling 需要克服的主要困难有两个:

  1. 没有配套的数据系统:训练过程中的数据大多数没有业务价值,不会配套专门的数据处理与存储系统;
  2. 数据量庞大:每个 GPU 在一个训练 Step 内就会产生数万个事件,而总数据量会随着 step 树与节点数增长而快速爆炸;

Probing 的解决方案是:本地化存储数据 + 分布式查询分析,将数据存储和分析的压力分散到每个节点上。以下是一个简单的示意图,用于说明理想情况下 probing 如何工作:

---
title: Probing 分布式 Profiling 架构
---
graph TD
    subgraph "控制平面 (用户)"
        UI[Web UI]
        CLI[命令行]
        API[SQL查询+HTTP协议]
        UI & CLI --> API
    end

    subgraph "分布式训练集群"
        direction LR
        subgraph "Node 1 (Rank 0)"
            P1[训练进程 Rank 0]
            PR1[Probe]
            H1[采集Hooks e.g., PyTorch]
            P1 --> H1 -- 本地数据 --> PR1
        end
        subgraph "Node 2 (Rank 1)"
            P2[训练进程 Rank 1]
            PR2[Probe]
            H2[采集Hooks e.g., PyTorch]
            P2 --> H2 -- 本地数据 --> PR2
        end
        subgraph "Node N (Rank N-1)"
            PN[训练进程 Rank N-1]
            PRN[Probe]
            HN[采集Hooks e.g., PyTorch]
            PN --> HN -- 本地数据 --> PRN
        end
    end

    API -- SQL查询 --> PR1;
    PR1 -- 分布式查询协调 --> PR2;
    PR1 -- 分布式查询协调 --> PRN;
    PR2 -- 本地查询/聚合 --> PR2;
    PRN -- 本地查询/聚合 --> PRN;
    PR2 -- 部分结果 --> PR1;
    PRN -- 部分结果 --> PR1;
    PR1 -- 最终聚合 --> API;

    style P1 fill:#f9f,stroke:#333
    style P2 fill:#f9f,stroke:#333
    style PN fill:#f9f,stroke:#333
    style PR1 fill:#bfb,stroke:#333
    style PR2 fill:#bfb,stroke:#333
    style PRN fill:#bfb,stroke:#333
    style H1 fill:#ccf,stroke:#333
    style H2 fill:#ccf,stroke:#333
    style HN fill:#ccf,stroke:#333

在这个架构下,可以借助分布式查询系统,将过滤、采样与聚合操作下推到每个节点去执行,并结合良好设计的采样机制与策略来平衡性能分析的精度与开销。接下来是在这个架构下设计数据采集、存储和分析的链路

采集链路

基于钩子的数据采集

虽然修改代码加日志是最直观的数据采集手段,也日志往往过于随意、缺乏设计,为后续的分析与使用带来困难。不修改代码采集数据就需要对代码进行自动插桩。好在 PyTorch 提供了钩子(Hooks)机制,能够”不侵入”代码的情况下完成插桩。

from torch.optim.optimizer import register_optimizer_step_post_hook

register_optimizer_step_post_hook(optimizer_step_post_hook)

register_optimizer_step_post_hook 帮我们向 torch 注册一个钩子函数,在每个 Optimzier 完成step()调用后执行。这个插桩时机极为关键:

  1. 模型已完成构建,可获取完整模型定义
  2. 前向传播、反向传播与优化器都已完成预热

接下来,借助 Python 的垃圾回收 (GC) 机制与反射能力来捕获进程中的模型结构:

def get_toplevel_module():
    import gc

    import torch

    objs = [obj for obj in gc.get_objects() if isinstance(obj, torch.nn.Module)]
    is_child = set()
    for obj in objs:
        for child in obj.children():
            is_child.add(id(child))
    return [obj for obj in objs if id(obj) not in is_child]

通过gc模块我们可以获得当前进程中的全部 Python 对象列表,再通过反射调用isinstance(obj, torch.nn.Module)找出全部torch.nn.Module对象。最后再根据 module 之间的父子关系来发现顶层 Module

获取顶层 Module 后,我们可以注册完整的前向 / 反向传播钩子链,完成接下来的插桩:

  1. Module.register_forward_pre_hook - 前向传播开始前
  2. Module.register_forward_hook - 前向传播完成后
  3. Module.register_full_backward_pre_hook - 反向传播开始前
  4. Module.register_full_backward_hook - 反向传播完成后
  5. Optimizer.register_step_pre_hook - 优化器步骤开始前
  6. Optimizer.register_step_post_hook - 优化器步骤完成后

这些钩子构成了训练过程中的完整监控链,允许我们精确测量模型各组件的执行性能。

结构化采样

考虑到 PyTorch 模型包含大量嵌套子模块,对每个模块都执行计时操作会带来显著性能开销。随机采样虽然能够降低插桩的开销,但需要等待较长时间才能保证采样充分。这里我们引入一种结构化采样方法来加速性能数据的采集:

  1. span 分解:将模型执行分解为一系列 span,每个 module 的前向和反向传播分别构成独立 span
  2. 层次化排序:按照嵌套关系对 span 进行排序
    • 粗粒度 span(如整个模型的前向传播)排序靠前
    • 细粒度 span(如单个卷积层的操作)排序靠后
  3. 自适应采样:从粗到细逐步采样
    • 命中采样时,记录当前 span 计时,并移至下一个 span
    • 未命中采样时,跳过计算以减少开销

这种结构化采样确保每个训练步骤只对一个特定粒度的 span 进行采样,使模型性能分析由粗到细逐步进行,在控制开销的同时提供全面性能视图。

基于 CUDA Event 的精确计时

GPU 上异步执行的计时通常通过 CUDA Event 来实现。CUDA Event 能保证在 CUDA Stream 上的执行顺序,并且是测量 GPU 操作时间的最准确方式。一个 CUDA Event 的生命周期包括以下几个阶段:

  1. 创建 (Create):通过 torch.cuda.Event() CUDA 原生 API 创建 Event 对象
  2. 记录 (Record):通过 event.record() Event 标记到特定 CUDA Stream 的当前位置
  3. 同步 (Synchronize):通过 event.synchronize() 等待 Event 标记的操作完成
  4. 查询 (Query):通过 event.query() 非阻塞地检查 Event 是否完成
  5. 计时 (Elapsed Time):通过 start_event.elapsed_time(end_event) 计算两个 Event 之间的时间差

在实际应用中,同步 (Synchronize) 操作会导致 GPU 等待并强制 Stream 清空,可能显著影响性能。为解决这一问题,我们采用延迟计时 (Delayed Timing) 策略,将时间读取推迟到优化器执行完成后进行。这种方法有效降低了计时操作对训练性能的干扰,特别适合分布式训练环境。

基于统计的性能 / 故障分析方法

在大规模分布式训练环境中,我们面临的不仅是如何采集数据,更重要的是如何有效利用这些数据发现并解决问题。Probing 采用统计分析方法,将分散在各节点的性能数据转化为可操作的洞察。

分布式训练中的常见性能问题

在实践中,分布式训练的性能问题通常表现为以下几种典型模式:

  1. 慢节点 (Straggler) 问题:个别节点显著慢于集群平均水平,拖慢整体训练进度
  2. 负载不均衡:计算或内存负载在节点间分布不均,导致资源利用率低下
  3. 通信瓶颈:节点间数据交换速度不足,制约训练效率提升
  4. 异常波动:性能指标在时间维度上出现突发性异常
  5. 集群分层:性能根据硬件配置或网络拓扑自然分层,形成性能梯队利用统计数据定位问题

节点性能差异分析

通过简单 SQL 查询,我们可以快速识别集群中的异常节点:

-- 查找前向传播耗时异常的节点
SELECT 
    rank, 
    AVG(duration_ms) as avg_forward_time,
    COUNT(*) as sample_count,
    (AVG(duration_ms) - 
     (SELECT AVG(duration_ms) FROM torch_traces WHERE operation='forward')) 
     / (SELECT STDDEV(duration_ms) FROM torch_traces WHERE operation='forward') 
     as z_score
FROM python.torch_traces
WHERE operation = 'forward' AND step_id BETWEEN 100 AND 200
GROUP BY rank
HAVING z_score > 2.0  -- 标准差超过2倍的视为异常
ORDER BY avg_forward_time DESC;

这种查询允许我们立即发现性能显著偏离集群平均水平的节点,而无需手动检查每个节点的 timeline

层次性能分布图

分布式训练中,模型的不同组件在不同节点上的性能表现极具研究价值。Probing 通过层次性能分布图直观展示这种多维度性能数据,帮助工程师快速定位瓶颈。通过 Probing 可以采集如下格式的数据:

ts: 事件时间戳
node:节点名称
module:模块名称
stage:阶段名称,比如forward或者backward
mem_allocated: 已经分配的显存
mem_cached: 已经缓存的显存
duration:时间开销

通过对采集的结构化数据进行多维度聚合与可视化,我们可以构建如下分析图表:

-- 分析每个模型层在不同节点上的性能分布
SELECT
    module,
    node,
    AVG(duration_ms) as avg_duration,
    PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY duration_ms) as median,
    PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY duration_ms) as p95,
    COUNT(*) as samples
FROM python.torch_traces
WHERE operation = 'forward' AND step_id BETWEEN 1000 AND 2000
GROUP BY module, node
ORDER BY module, avg_duration DESC;

这种查询能够生成深度学习模型中每个组件在集群不同节点上的性能热力图,通过这种热力图,我们可以立即观察到:

  • 水平方向:同一节点上不同模型层的相对性能
  • 垂直方向:同一模型层在不同节点上的性能差异
  • 热点区域:特定节点 - 组件组合的性能异常

在一次实际分析中,我们通过层次性能分布图发现了某 DNN 模型中有趣的性能模式:

  • 组件级差异:Attention 层在所有节点上都比其他层耗时更长(水平模式)
  • 节点级差异:特定的 4 个节点在处理卷积层时显著慢于其他节点(垂直模式)
  • 交互效应:某些节点仅在处理特定类型的层时出现性能下降(局部热点)

时间维度的性能演变分析

分布式训练的性能问题常常随时间动态变化。通过跟踪关键指标的时间序列,我们可以发现潜在问题:

-- 分析训练过程中的性能趋势
SELECT 
    FLOOR(step_id / 50) * 50 as step_bucket,  -- 按50步为单位分桶
    AVG(duration_ms) as avg_duration,
    STDDEV(duration_ms) / AVG(duration_ms) as cv  -- 变异系数
FROM torch_traces
WHERE operation = 'forward' 
GROUP BY step_bucket
ORDER BY step_bucket;

通过这种分析,我们可以发现:

  • 训练初期的预热效应
  • 性能随时间的逐渐劣化
  • 可能的内存泄漏或资源竞争问题
  • 周期性波动(如系统 GC 或后台任务影响)

分布式系统的层次化分析

在大型集群中,仅分析个体节点往往不够。Probing 支持按网络拓扑、硬件型号等进行分组分析:

 # 按网络拓扑分组分析通信性能
rack_perf = probe.sql("""
    SELECT 
        CASE 
            WHEN src_rank / 8 = dst_rank / 8 THEN 'same_node'
            WHEN src_rank / 32 = dst_rank / 32 THEN 'same_rack'
            ELSE 'cross_rack'
        END as topology,
        AVG(bytes_per_sec) as avg_bandwidth,
        COUNT(*) as sample_count
    FROM comm_events
    GROUP BY topology
""").fetchall()

for row in rack_perf:
    print(f"{row.topology}: {row.avg_bandwidth/1e9:.2f} GB/s ({row.sample_count} samples)")
# 输出:
# same_node: 87.32 GB/s (12453 samples)
# same_rack: 23.76 GB/s (8721 samples) 
# cross_rack: 11.89 GB/s (5432 samples)

这种分析揭示了网络拓扑对通信性能的影响,启发我们优化通信算法和数据分片策略以减少跨机架通信。

Probing 分布式探针开发随笔(二:探针机制

引言

在前一篇文章中,我们介绍了探针思路的设计理念,以及 Probing 分布式探针系统的整体架构。本文将详细介绍 Probing 的探针机制,包括探针的动态注入与运行时加载,以及如何规避 C/C++ 常见的 ABI 兼容性问题。

为何探针的动态注入能力尤为重要?因为故障和性能问题的发生总是不期而至,我们无法保证每次出现问题时都能提前部署探针。因此,任何需要提前部署的工具都迫使工程师必须”复现”问题才能进行分析,这无疑大大增加了诊断难度和时间成本。而分布式场景下,复现的成本与难度更是倍增,毕竟难以预留千卡或者万卡资源来复现问题。

异构计算则是另一个让复现问题变得更加困难的因素。在异构计算中,程序状态不再单纯地保存在 CPU 的内存中,而是同时分布在 GPUTPU 等计算单元的内存中。这些计算设备的内存中不存在类似调用栈这种结构化数据,我们无法简单地通过 dump 调用栈来捕获故障时刻的状态,而是需要 dump 整个计算设备的内存内容。对于常见的单机八卡配置,完整 dump 一次设备内存需要占用 640GB 的存储空间,这无疑是一个巨大的挑战。而管理这些数据的元数据通常存储在 Python 解释器中,这意味着必须开发一个跨设备、跨语言的调试工具,才能实现完整的故障诊断。

探针则是尝试另一种解决问题的思路:

  • 通过动态注入,即可实现在任意条件下调试与诊断;
  • 借助探针动态操作目标进程的 Python 解释器,利用其自然可以实现跨语言、跨设备的调试能力;

探针机制

探针注入的关键在于在目标进程的代码逻辑之外,额外向进程植入一段代码。常见的代码植入方式有两种:

  1. LD_PRELOAD方法:通过LD_PRELOAD环境变量,可以让 ld.so 在加载目标进程的时候,优先加载指定的动态链接库从而实现代码植入。这种方法的优点是简单易用,但是只能在进程启动时生效,无法在进程运行时动态注入;
  2. ptrace方法:通过ptrace系统调用,可以在进程运行时动态修改进程的内存,从而实现代码植入。这种方法的优点是可以在进程运行时动态注入,但是需要对目标进程有一定的权限,且对目标进程的性能影响较大。

本文重点介绍ptrace方法的实现,LD_PRELOAD方法介绍的文章很多,本文不再赘述。

ptrace系统调用介绍

ptrace是一个 Linux 系统调用,用于监控和控制另一个进程。ptrace的调用方式如下:

#include <sys/ptrace.h>

long ptrace(enum __ptrace_request op, pid_t pid,
            void *addr, void *data);

ptrace 提供了一种控制目标进程执行的方法,它可以让调试器与目标进程进行交互,从而实现调试功能。__ptrace_request 常用的取值如下:

  • PTRACE_ATTACH: 附加到目标进程,使其成为当前进程的 tracee
  • PTRACE_INTERRUPT: 暂停目标 tracee
  • PTRACE_CONT: 让目标进程继续执行;
  • PTRACE_DETACH: 释放目标 tracee
  • PTRACE_GETREGS/PTRACE_SETREGS: 读写目标进程寄存器;
  • PTRACE_PEEKDATA/PTRACE_POKEDATA: 读写目标进程内存,一次一个 WORD
  • /proc/<pid>/mem: 大块读写内存;

常见的一个 debugger 的工作流程如下:

  1. attach 到目标进程;
  2. 通过读写目标进程 TEXT 段插入断点;
  3. 恢复目标进程执行,并用waitpid等待目标进程断点暂停;
  4. 等到目标进程暂停,通过读写内存查看信息;

探针注入流程

这里参考了 https://github.com/Artemis21/ptrace-inject 项目,进行了一些修改。注入流程如下:

  1. 通过PTRACE_ATTACH附加到目标进程;

Rust 中可以通过pete库对ptrace的封装来使用ptrace系统调用:

let mut tracer = pete::Ptracer::new();
tracer
    .attach((&proc).into())
    .context("failed to attach to given process")?;
log::trace!("Attached to process with PID {}", proc);
  1. 写入 shellcode 到目标进程的内存中;

首先找到一处合适的内存地址,具有执行权限,可以写入 shellcode。这里我们通过读取目标进程的内存映射信息,找到一个具有执行权限的内存区域:

/// Find a suitable address to inject the shellcode into.
pub(crate) fn find_executable_space(&self) -> Result<u64> {
    log::trace!("Finding executable space in target process");
    self.0
        .maps() // 读取 /proc/<pid>/maps 文件,获取进程的内存映射信息
        .context("failed to read process memory maps to find executable region")?
        .into_iter()
        .find(|m| m.perms.contains(process::MMPermissions::EXECUTE))
        .map(|m| m.address.0)
        .ok_or_else(|| {
            anyhow::anyhow!("could not find an executable region in the target process")
        })
}

上述代码通过读取/proc/<pid>/maps文件,获取进程的内存映射信息,找到一个具有执行权限的内存区域。接下来我们先保存这个内存区域的内容,然后写入 shellcode

// 打开 /proc/<pid>/mem 文件,供后续读写内存使用
let mem = fs::OpenOptions::new().read(true).write(true)
    .open("/proc/<pid>/mem")?;

// 根据偏移量,读取目标进程的内存
let len = mem.read_at(data, addr)?;

// 将shellcode写入目标进程的内存
let len = mem.write_at(shellcode, addr)?;

其中data是一个[u8; 1024]大小的数组,用于保存原内存区域的内容;shellcode是我们要写入的 shellcode,内容如下

/// The x64 shellcode that will be injected into the tracee.
const SHELLCODE: [u8; 6] = [
    // Nop slide to make up for the fact that jumping is imprecise.
    0x90, 0x90, // nop; nop
    // The tracer does most of the work by putting the arguments into the
    // relevant registers, and the function pointer into `r9`.
    0x41, 0xff, 0xd1, // call r9
    // Trap so that the tracer can set up the next call.
    0xcc, // int3
];

shellcode 主要由三部分组成:

  • 两个nop指令,避免跳转时的不精确性带来问题;
  • 一个call r9指令,调用r9寄存器中的函数指针,此处调用会遵循 X86_64 下的标准调用协议,通过寄存器传参;
  • 一个int3指令,触发中断,控制流程回到 tracer
  1. 通过设置寄存器调用目标函数:

tracer 中设置寄存器,让目标进程调用函数:

self.tracee
    .set_registers(pete::Registers {
        rip: shellcode_address,
        // shellcode会通过r9寄存器调用函数
        r9: fn_address,
        // 根据x86-64 ABI要求,将函数入参传递到寄存器中
        rdi,
        rsi,
        // 根据x86-64 ABI要求,确保栈指针对齐到16字节
        rsp: self.saved_registers.rsp & !0xf,
        ..self.saved_registers
    })

函数fn_address是我们要调用的函数在目标进程中的虚拟地址,rdirsi是根据 x86-64 调用约定传递的前两个函数参数,rsp是栈指针,必须对齐到 16 字节以符合 ABI 要求。特别注意,fn_address必须是目标进程地址空间中的有效地址,否则会触发SIGSEGV信号导致进程崩溃。而目标进程的地址是不固定的,我们需要通过函数相对 so 文件的偏移量来计算。首先分别获取libc.so tracer tracee 中的地址,可以通过/proc/<pid>/maps文件获取每个 so 映射到内存的地址。再根据函数在 tracer 中的地址计算函数在libc.so中的偏移量。最后在 tracee 中根据libc.so的地址与函数偏移量计算目标函数在 tracee 中的真实地址,即可根据该地址进行调用。

获取函数真实地址的代码比较冗长,感兴趣的话可以参考仓库中的源码

通过上述步骤,我们可以在 tracee 中调用dlopen函数,加载动态链接库,实现动态注入。

探针实现

ptrace只是帮助我们实现了探针的动态注入,而真正的探针逻辑还需要我们自己实现。根据前文所述,借助ptrace可以让目标进程调用dlopen来加载动态链接库。而在动态库加载的过程中,会读取 ELF(Executable and Linkable Format) 文件中的.init_array段,该段中存放了一系列初始化函数的地址。C/C++ 编译器一般支持__attribute__((constructor))属性,可以将函数注册到.init_array段中。

__attribute__((constructor)) void my_init() {
    // 初始化代码
}

Rust 中可以通过#[ctor]宏实现类似的功能:

#[ctor]
fn my_init() {
    // 初始化代码
}

Probing 的注入框架不仅支持其内置探针模块,还支持用户自定义的探针库,提供了极高的扩展性。关于探针的具体设计细节,我们将在后续文章中深入探讨。

ABI 兼容性

传统的 C/C++ 项目经常受 ABI(Application Binary Interface)兼容性的困扰。常见的 ABI 兼容性问题有两类:

  1. glibc 中函数的版本问题:为了保证 ABI 的兼容性,glibc 中的函数会有多个版本,比如malloc函数就有malloc@GLIBC_2.2.5malloc@GLIBC_2.3等多个版本。而动态链接库在链接时会在当前 glibc 中选取一个最新的版本,这就导致了在较新的系统下编译的 so 文件在较旧的系统上无法运行;
  2. C++ ABI 问题:C++ ABI 问题主要由于最近几年 C++ 标准的更新较快,导致 libstdc++ 库的 ABI 不断变化。其中最为常见的一种错误是std::string类型在 C++11 标准中引入了短字符串优化(SSO)机制,导致std::string的内存布局发生了变化。而在 C++11 之前编译的 so 文件在 C++11 标准下运行时,会出现内存布局不一致的问题;

Probing 主要通过两种方式解决 ABI 兼容性问题:纯静态链接与 zigbuild

纯静态链接

静态链接是解决 ABI 兼容性的一种经典方法,通过将所有依赖库代码打包到一个 so 文件中,并在链接阶段完成所有符号的解析,从而避免了运行时出现 ABI 问题。Rust 在构建 so 文件的时候默认使用纯静态链接,能够很大程度上避免 C/C++ 项目中的 ABI 兼容性问题。

zigbuild

Zig 是一种新兴的系统级编程语言,内置完整的交叉编译工具链,可针对不同 glibc 版本生成二进制文件:

zig cc main.c -o main -Dtarget=arch64-linux-gnu.2.31

这使得使用 Zig 工具链构建的 so 文件可以通过指定低版本的 glibc 来增加 so 文件的兼容性。

cargo-zigbuild Rust 构建工具cargo的一个扩展,可以在编译时指定 glibc 的版本,并借助 Zig 的工具链完成 so 文件的链接。

cargo zigbuild --target x86_64-unknown-linux-gnu.2.17

打包发布

前文已经讨论了探针的动态注入与 ABI 兼容性问题,两者都尽最大的可能让 Probing 可以在任意环境下直接运行,而无须额外的配置。接下来我们将讨论 Probing 的打包发布问题,这是让 Probing 真正成为一个通用的工具的关键。

二进制工具发布通常有三种渠道:

  1. 发布源码:将源码发布到 github 等代码托管平台,用户可以自行编译;但往往构建一个复杂项目的环境是非常困难的,尤其是在分布式环境下;
  2. 发行版包管理器:将二进制工具打包成 rpmdeb 等包,发布到发行版的包管理器中,用户可以通过包管理器安装;但是不同发行版的包管理器不同,维护成本较高;并且同一个发行版的不同版本需要维护不同的包;
  3. pip/conda 等第三方发布平台:将二进制工具打包成 pip/conda 包,发布到第三方平台,用户可以通过 pip/conda 安装;但是这种方式往往需要用户安装额外的包管理器,不够方便;

不过对于 AI 领域的工具来说,Python 是必不可免的,因此基于 Python 包管理工具 pip 或者 conda 来发布 Probing 是一个不错的选择。

不同于一般的 python 包,Probing 是一个以 Rust 为主要开发语言的工具,因此并不适合使用 setup.py 等传统方式来构建 python 包。这里我们选择直接使用脚本来打包whl :

def write_wheel_file(filename, contents):
    with WheelFile(filename, "w") as wheel:
        for member_info, member_source in contents.items():
            ...
    return filename


def write_wheel(out_dir, *, name, version, tag, metadata, description, contents):
    name_snake = name.replace("-", "_")
    wheel_name = f"{name_snake}-{version}-{tag}.whl"
    dist_info = f"{name_snake}-{version}.dist-info"
    return write_wheel_file(
        os.path.join(out_dir, wheel_name),
        {
            **contents,
            f"{dist_info}/METADATA": make_message(...),
            f"{dist_info}/WHEEL": make_message(...),
        },
    )


def write_probing_wheel(
    out_dir, *, platform="manylinux_2_12_x86_64.manylinux2010_x86_64"
):
    ...

    for name, path in {
        "probing": "target/x86_64-unknown-linux-gnu/release/probing",
        "libprobing.so": "target/x86_64-unknown-linux-gnu/release/libprobing.so",
    }.items():
        zip_info = ZipInfo(f"probing-{metadata["version"]}.data/scripts/{name}")
        zip_info.external_attr = (stat.S_IFREG | 0o755) << 16
        with open(path, "rb") as f:
            contents[zip_info] = f.read()
    ...
    return write_wheel(
        out_dir,
        name="probing",
        version=metadata["version"],
        tag=f"py3-none-{platform}",
        metadata={...},
        description=description,
        contents=contents,
    )


def main():
    wheel_path = write_probing_wheel("dist/")
    with open(wheel_path, "rb") as wheel:
        print(f"  {wheel_path}")
        print(f"    {hashlib.sha256(wheel.read()).hexdigest()}")


if __name__ == "__main__":
    main()

该脚本主要使用wheel包中的WheelFile类来构建whl文件,并将构建出来的二进制写入到probing-{version}.data/scripts目录下。此外需要提供METADATAWHEEL文件,分别用于描述包的元信息和 wheel 的版本信息。

总结

本文主要讨论了 Probing 的核心机制——探针注入,并讨论了如何将这一机制变成一个通用工具,让其能使用到复杂多样的生产环境中,能够快速发布给尽可能多的用户。所有这些设计都是为了 Probing 的一个核心设计理念:解决问题时,应直接面对根本问题,避免陷入工具配置、环境搭建等元问题的循环中。或者可以认为这一设计理念是马斯克第一性原则的一种体现,缩短解决问题的路径,提高解决问题的效率。

在下一篇文章中将会介绍探针 so 的设计与实现。

Probing 分布式探针开发随笔(一:背景与设计理念

分布式训练系统的泥潭

在过去半年多的时间里,我一直在支持千卡规模的 LLM 分布式训练。坦白讲,千卡训练的过程并不愉快,尤其是在性能调优和故障排查方面。在这个过程中,我们遇到了一系列棘手的问题:训练无法正常启动、通信库突然 hang 住、节点性能不及预期、训练速度不稳定等等。这些问题不仅严重影响了训练效率,还大幅增加了调试的复杂度,导致我们不得不花费大量时间和精力在性能调优和故障排查上。

有人可能会说,千卡(乃至万卡)规模的稳定性问题在大厂内部已经解决得相当好了。然而,那些耗费无数人力堆砌出来的系统,往往只是在这些大厂已有的复杂基础设施上打补丁,解决眼前可见的问题,而且很多时候仅仅是在处理问题的表象。大规模分布式异构训练真正需要的是类似 Hadoop、Spark、Kubernetes TensorFlow 这样具有前瞻性的系统设计,能够解决问题的本质,并提供解决问题的框架,而不仅仅是一些堆砌在特定基础设施上、不具备任何迁移性的”补丁”。我们需要一种更加系统化、可扩展的方法来应对这些挑战。

Probing——分布式探针系统的原型探索

在解决问题的过程中,我一直思索自己到底需要什么。我需要一种能够在任何时刻动态启用,无需预先部署或插桩,在生产任务中以极低性能开销持续运行,实现实时监控与故障追溯的诊断工具。我需要一种不仅支持单机诊断,还能无缝覆盖分布式训练环境,无论集群规模如何,都能确保数据采集与故障分析的一致性的诊断工具。我需要一种能够从硬件层面的诊断数据、芯片互联状态,到框架、系统和模型各层数据的全面采集,构建完整的闭环监控系统的诊断工具。而现有的种种工具,要么需要侵入式的代码修改和预先部署,要么会严重影响性能,要么只能关注单机,无法覆盖分布式环境,要么只能关注单一维度,无法实现综合分析。

基于自己的需求,我开始尝试设计一种“探针”系统:

  • 可以在任意时刻通过动态注入的方式启用,无需预先部署或插桩;
  • 运行开销极低或者无开销,可以在生产任务中持续收集性能数据和故障数据;
  • “寄生”在目标进程中,具有相同的内存地址空间与权限,进而实现观测和调试;
  • 支持分布式,更好地覆盖大规模分布式训练环境;

这套探针系统大致用法如下:

$ probing <pid> inject # 注入探针
$ probing <pid> eval "print('Hello, Probing!')" # 在目标进程中执行代码
$ probing <pid> query "SHOW tables" # 查看可用数据
$ probing <pid> query "SELECT * FROM process.envs" # 查询进程环境变量

probing通过query命令提供 SQL 查询接口,并在这一接口下标准化了不同类型的数据,包括进程状态、硬件性能指标、网络状态、文件系统状态等,使用户无须单独学习每种数据的获取和分析方式。另一方面,SQL 查询也提供和 AI 接入能力,用户可以借助 AI 生成查询与分析语句,实现自动化的性能分析与故障诊断。后续也会直接扩展 SQL 支持分布是查询,实现对整个集群的性能分析与故障诊断。

在接下来的一系列文章里,我将详细介绍 Probing 的设计与实现,包括探针机制、数据采集、分析方法等方面。希望这个探索能够为大规模分布式训练的性能分析与故障诊断提供一些启发。以下是接下来需要进行讨论的内容:

  1. 如何实现探针的动态注入与运行时加载,如何规避 C/C++ 常见的 ABI 兼容性问题;
  2. 如何实现高频数据的采集和存储,如何实现数据的压缩和优化;
  3. 如何避免跨节点时钟漂移带来的事件时间不一致问题;

Training Dynamics Outlier——LLM 模型训练过程中的数值特性分析

Training Dynamics 是一个未被严格定义的词,泛指模型训练过程中观测到的各种现象、变化和行为规律。我们可以从 loss、泛化 loss、梯度大小以及等等表现来观察模型内部的演变机制,并总结出类似涌现现象(Emergency、Scaling Law、Double Decent Gradient Pathologies 等现象。

特别地,权重矩阵与激活值的动态演变 (Dynamics) 会直接影响数值表达范围,进而决定硬件计算精度选择与量化误差控制策略。本文聚焦 Transformer 架构中关键组件的数值动态特性,重点分析其对低精度训练与推理的工程影响。

权重与激活的数值演变特征

这里先给出权重与梯度的直观数值变化,帮助直观理解训练过程。下图取自某开源仓库 1,展示了权重数值的直方分布随训练进行的变化情况:



可以发现,各个 block FFN 部分权重从随机初始化的高斯分布,开始时较为稳定;在 2000 step 左右开始剧烈变化;随后整体分布再次稳定下来。权重整体保留了高斯分布,但是存在一些不是非常大的 outlier

接下来再看一下激活值的分布变化,在训练开始后,残差激活值迅速从高斯分布转变为逻辑分布(Logistic Distribution,并且出现较大的 outlier



这种激活上的 outlier 会对模型量化过程产生极大的影响,因此诸如 AWQ 等量化方法会重点关注激活中的 outlier 情况,以保证模型推理时的精度。

梯度分布的变化趋势与权重类似,训练过程也未出现较大的 outlier,说明梯度本身也具备较好的稳定性,存在低精度计算和存储的可能性。



INT8 也能训练

前一篇博客 中,我们深入探讨了 DeepSeek V3 如何通过 FP8 实现高效训练,并成功克服了精度挑战。本文探讨另一个问题:如果用 INT8 代替 FP8 做训练,会发生什么?

INT8 量化

给定一个浮点数向量 \(x \in \mathbb{R}^n\)INT8 量化的目标是将其映射到 [-128, 127] 的整数空间。这一过程需要确定缩放因子 \(\alpha\) 和零点偏移 \(\beta\),使得:

\[ x_q = round(\frac{x}{\alpha}) + \beta \]

其中 \(x_q\) 表示量化后的 INT8 值。缩放因子 \(\alpha\) 通常通过以下方式计算:

\[ \alpha = \frac{max(|x|)}{127} \]

这确保了量化后的值不会超出 INT8 的表示范围。而零点偏移 \(\beta\) 在对称量化场景下通常设置为 0,在非对称量化时则需要根据数据分布来确定。

对于 LLM 训练场景,由于权重和激活值通常呈现对称分布,我们可以使用对称量化方案:

def symmetric_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, float]:
    alpha = x.abs().max() / 127.0  # 计算缩放因子
    x_q = torch.round(x / alpha)   # 量化
    x_q = torch.clamp(x_q, -128, 127)  # 截断
    return x_q, alpha

反量化操作则是将 INT8 值映射回浮点数空间:

\[ x_r = (x_q - \beta) \times \alpha \]

其中 \(x_r\) 是反量化后的浮点数值。在对称量化场景下,由于 \(\beta = 0\),反量化简化为:

def symmetric_dequantize(x_q: torch.Tensor, alpha: float) -> torch.Tensor:
    return x_q * alpha

FP8 的浮点量化不同,INT8 采用均匀量化方案:

  • 优势区间:大值区域精度更高(固定量化步长)
  • 劣势区间:小值区域精度较低(相对误差更大)

这种特性使得 INT8 对数据分布形态更为敏感,需要针对性优化策略。

DeepSeek V3 FP8 训练的挑战

DeepSeek V3 的发布引起了对 FP8 训练的广泛关注,业界也出现了大量文章解析 How 的问题——DeepSeek 是怎么进行 FP8 训练的,与传统方案有哪些不同。但是目前鲜有文章对 Why 问题进行深入探讨,为何 DeepSeek 的方案能够取得成功。本文尝试对 FP8 训练所面临的挑战进行深入解析,并尝试猜测 DeepSeek 团队设计其 FP 方案的背后原理(如果你对 INT8 训练感兴趣,可以参考本文的姊妹篇:INT8 训练

1. FP8 浮点格式

1.1 FP8 格式的历史

FP8 是一种遵循 IEEE 754 规范 1 8 位浮点数格式,由 Nvidia 2022 年发布的 H100 GPU 中首次引入。在此之前,Nvidia 硬件上浮点数格式的发展历程如下 2

  • 2016 P100 GPU 首次引入 FP16 数据格式,直接开启了深度学习混合精度训练的技术路线;
  • 2017 V100 GPU 首次引入 Tensor Core, 用于加速 FP16 矩阵乘法运算;
  • 2020 A100 GPU 首次引入 TF32 数据格式,可通过 Tensor Core 加速;引入 bfloat16 数据格式,提供比 FP16 更宽的动态范围(当下 BF16 已经成为 LLM 训练的主流方案
  • 2022 H100 GPU 首次引入 FP8 数据格式;

FP8 Nvidia 给予厚望,认为其成功的延续了 CEO 提出的 Huang’s Law3,即 10 年间 GPU 硬件算力提升 1000 倍。在过去的 10 年间,新型数值表达的引入了 16 倍算力提升,是诸多技术中贡献最大者,GPU 架构与复杂指令集紧随其后带来了 12.5 倍提升,而制程进步带来的收益非常有限,仅 2.5 4

1.2. 常见浮点数与 IEEE 754

IEEE 754 是目前广为使用的浮点数规范,定义了浮点数的 bitwise 表达与量化方式。浮点数的二进制表达分为三部分:

  • 符号位(sign)
  • 指数位(exponent)
  • 尾数位(mantissa)

常见的浮点数格式的二进制表达如下图所示:

block-beta
    columns 33
    FP32["fp32"]
    S1["S"]
    E1["E"]
    E2["E"]
    E3["E"]
    E4["E"]
    E5["E"]
    E6["E"]
    E7["E"]
    E8["E"]
    M1["M"]
    M2["M"]
    M3["M"]
    M4["M"]
    M5["M"]
    M6["M"]
    M7["M"]
    M8["M"]
    M9["M"]
    M10["M"]
    M11["M"]
    M12["M"]
    M13["M"]
    M14["M"]
    M15["M"]
    M16["M"]
    M17["M"]
    M18["M"]
    M19["M"]
    M20["M"]
    M21["M"]
    M22["M"]
    M23["M"]

    BF16["bf16"]
    SS1["S"]
    EE1["E"]
    EE2["E"]
    EE3["E"]
    EE4["E"]
    EE5["E"]
    EE6["E"]
    EE7["E"]
    EE8["E"]
    MM1["M"]
    MM2["M"]
    MM3["M"]
    MM4["M"]
    MM5["M"]
    MM6["M"]
    MM7["M"]
    space:16

    FP16["fp16"]
    space:3
    ss1["S"]
    ee1["E"]
    ee2["E"]
    ee3["E"]
    ee4["E"]
    ee5["E"]
    mm1["M"]
    mm2["M"]
    mm3["M"]
    mm4["M"]
    mm5["M"]
    mm6["M"]
    mm7["M"]
    mm8["M"]
    mm9["M"]
    mm10["M"]
    space:13

    E5M2["fp8"]
    space:3
    s1["S"]
    e1["E"]
    e2["E"]
    e3["E"]
    e4["E"]
    e5["E"]
    m1["M"]
    m2["M"]
    space:21

    E4M3["fp8"]
    space:4
    sss1["S"]
    eee1["E"]
    eee2["E"]
    eee3["E"]
    eee4["E"]
    mmm1["M"]
    mmm2["M"]
    mmm3["M"]
    space:21

    classDef name fill:#00000000, stroke:#00000000
    class FP32,BF16,FP16,E4M3,E5M2 name

    classDef sign fill:#EE0000, stroke:#00000000
    class S1,SS1,s1,ss1,sss1 sign

    classDef exp fill:#00EE00, stroke:#00000000
    class E1,E2,E3,E4,E5,E6,E7,E8 exp
    class EE1,EE2,EE3,EE4,EE5,EE6,EE7,EE8 exp
    class e1,e2,e3,e4,e5,e6,e7,e8 exp
    class ee1,ee2,ee3,ee4,ee5,ee6,ee7,ee8 exp
    class eee1,eee2,eee3,eee4,eee5,eee6,eee7,eee8 exp

1.3. FP8 有两种格式

随着浮点数位数从 16 位进一步降低到 8 位,动态范围不足的问题逐渐显现。因此 NvidiaArm Intel FP8 规范中设计了两种浮点数类型 5E4M3 E5M2

E4M3 E5M2
format(s/e/m) 1:4:3 1:5:2
Exponent bias 7 15
Infinities N/A S.11111.00
NaN S.1111.111 S.11111.{01,10,11}
Zeros S.0000.000 S.00000.00
Max normal S.1111.110 = \(1.75 \times 2^8\) = 448 S.11110.11 = \(1.75 \times 2^15\) = 57.344
Min normal S.0001.0000 = \(2^{-6}\) S.00001.00 = \(2^{-14}\)
Max subnorm S.0000.111 = \(0.875 \times 2^{-6}\) S.00000.11 = \(0.75\times 2^{-14}\)
Min subnorm S.0000.001 = \(2^{-9}\) S.00000.01 = $ 2^{-16}$

浮点数都会分配一些二进制表达来表示特殊值 **NaN** \(\mathbb{\pm}\)InfIEEE 754 规范约定使用指数位全 **1** 的二进制表达来表示这些特殊值。对于 E4M3 格式来说,若严格遵循 IEEE 754 规范,会 8 个二进制表达。因此在定义 E4M3 规范时对这些二进制表达进行了额外开发,仅在指数位尾数位同时全为 1 时才表示 NaN,全为 0 的时候表示 \(\pm\)Inf

H100 Tensor Core 提供 3 A100 FP16 性能,若启用 FP8 算力能够再次翻倍。

从强化学习到 DeepSeek R1

1. 什么是强化学习 (RL, Reinforcement Learning)

传统的机器学习,包括深度学习,其本质是数学性的,严格遵守函数的数学定义:对于给定输入,产生确定的输出

\[F(x) = y\]

随着输入 \(x\) 和输出 \(y\) 的不同,这一范式可以适配各种不同的任务,比如:

  • \(x\) 是图像,\(y\) 是类别,那么 \(F\) 就是 Resnet 这种图像模型;
  • \(x\) 是语音信号,\(y\) 是文字,那么 \(F\) 就是一个语音识别模型;
  • \(x\) 是文本输入,\(y\) 是文本输出,那么 \(F\) 就是时下火热的大语言模型;

强化学习(Reinforcement Learning)的本质上则是哲学性的,它探讨三个核心问题:

  • 我是谁?一个 Agent
  • 我在哪?处于某个 State
  • 到哪里去?采取一个 Action

如果站在上帝视角去观测这个 Agent,我们还会发现:

  • Agent 处在一个环境中(Environment)
  • Agent 有一个用来策略(Policy)告诉我该采取什么动作(Action)
  • 每执行一个动作(Action,环境都会给我反馈 (Reward)

以上就是强化学习中的主要概念。

alt text

2. 如何进行强化学习

这里以一个迷宫问题为例,介绍如何进行强化学习:

迷宫:(S: Start, E: End, W: Wall)

block-beta
  columns 3
  S1["S1(S)"] S2 S3["S3(W)"]
  S4 S5 S6
  S7["S7(W)"] S8 S9["S9(E)"]

这个迷宫就是一个 Environment。我们放置一个机器人在开始处(Start,让机器人自动学习如何走迷宫的策略(Policy。这个策略可以记成 \(\pi(s)\rightarrow a, s \in [1-9], a \in [上, 下, 左, 右]\)。开始时机器人对于迷宫一无所知,所以 \(\pi(s)会随机输出一个方向\)

关于分布式模型并行的分正式评论

关于 Data Parallel(DP、Tensor Parallel(TP)和 Pipeline Parallel(PP)等分布式并行策略,与 MegatronDeepSpeed FSDP 等实现的一些深入研究与讨论。分布式并行训练主要解决两类问题:

  1. 模型分片:模型大小远超单节点存储上来,需要多节点分片存储和计算担;
  2. 并行训练:提高单位时间内的算力密度,进而降低整体训练时间;
    分布式并行训练几乎总是会引入昂贵的成本,比如增加了昂贵的多节点通信、引入了额外的多机稳定性问题、以及额外的开发与调试成本等,因此我们应该尽量避免引入分布式并行训练。而不得不引入分布式训练的场景中,也应充分考虑通信开销,尽量降低并行的规模。

3D 模型并行

根据切分维度的不同,并行策略主要分为如下几类:

  1. Data Parallel(DP:将数据切分成 N 份,每个 instance 采用完全相同的配置,在计算梯度后通过 all reduce 全局同步梯度,并分别更新;
  2. Tensor Parallel(TP:将每个 tensor 切分成 N 份,在矩阵乘法等计算时进行同步;也称为横切
  3. Pipeline Parallel (PP:将模型按执行前后顺序切分多分(通常按 layer 切分,并根据顺序依次执行;
  4. Zero Redundancy Optimizer(ZeRO:同样将 tensor 切分成 N 份,但是在前向与后向计算时在每个分布式节点重建原始 Tensor
  5. Sequence Parallel(SP:在超长序列上进行训练时,将计算切分至多个节点;

dp_tp_pp.png