Adding Sequence Parallelism to Slime's FSDP Backend
Long-context RL training has a straightforward scaling problem: attention compute grows quadratically with sequence length, activation memory grows with it (linearly, once FlashAttention avoids materializing the attention matrix), and FSDP alone can’t help because it shards parameters, gradients, and optimizer state, not activations. Once you’re training Llama3-8B or Qwen3-7B with 16k+ context, a single GPU runs out of activation memory even with gradient checkpointing enabled. You need to split the sequence itself across GPUs.
This post documents the design for adding Sequence Parallelism (SP) to slime’s FSDP backend. The target is a 1-week MVP with two engineers, scoped tightly to Ring-Attention with a varlen API. I’ll walk through the route selection, system architecture, RL-specific coupling points, and the testing plan.
Why Sequence Parallelism, and Why Now
Slime’s FSDP backend already handles weight sharding and gradient accumulation well. But for the RL workloads we care about, the bottleneck has shifted. GRPO and PPO-style training on long-context tasks means generating and training on sequences of 8k-16k tokens or more. The memory pressure isn’t from parameters; it’s from the activations saved during the training forward pass, including the attention inputs (Q, K, V) that every layer keeps for the backward pass.
SP addresses this directly: partition the sequence across GPUs in a process group, have each GPU compute attention over its local chunk, and communicate KV blocks between neighbors. Activation memory per GPU shrinks roughly in proportion to 1/SP. For SP=2 on an 8-GPU node, each GPU holds half of each packed sequence through the forward and backward pass, cutting activation memory by around 35-50% depending on model architecture.
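A back-of-envelope model (a simplification, not a measurement) shows why the overall per-GPU saving lands below the ideal. Let f be the fraction of per-GPU memory taken by activations at SP=1 and P the SP degree, and assume activations scale as 1/P while weights, gradients, and optimizer state are unchanged by SP:

$$
\frac{M_P}{M_1} \approx (1 - f) + \frac{f}{P},
\qquad
\text{saving} = 1 - \frac{M_P}{M_1} \approx f\left(1 - \frac{1}{P}\right)
$$

At P = 2 the saving can never exceed 50%, and the >= 35% memory target for SP=2 in the benchmark table corresponds to activations making up at least about 70% of per-GPU memory at baseline.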
The prerequisites were already in place: FSDP weight cross-process updates (disaggregate mode), a working data packing pipeline, NVLink connectivity within nodes, and benchmark models (Llama3-8B, Qwen3-7B) that we run daily.
Ring-Attention vs. Ulysses: The Tradeoff
We evaluated two main approaches before committing.
Route A: Ring-Attention (selected)
Ring-Attention passes KV blocks around a ring of GPUs while each GPU computes local attention over its sequence chunk. We use the ring-flash-attention library, which provides a varlen-compatible API that slots directly into FlashAttention call sites.
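The property ring attention relies on is that attention over separate K/V blocks can be merged exactly, as long as each partial result carries its log-sum-exp (LSE). The pure-Python sketch below illustrates that merge for a single query vector, without causal masking or real communication; the function names are illustrative and are not the ring-flash-attention API.

```python
import math
from typing import List, Tuple

Vec = List[float]


def block_attention(q: Vec, keys: List[Vec], values: List[Vec]) -> Tuple[Vec, float]:
    """Softmax attention of one query over one K/V block.

    Returns the block-normalized output and the log-sum-exp (LSE) of the
    block's scores, which is all that is needed to merge with other blocks.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) / denom for j in range(dim)]
    return out, m + math.log(denom)


def merge_partials(out_a: Vec, lse_a: float, out_b: Vec, lse_b: float) -> Tuple[Vec, float]:
    """Exactly combine two partial attention results using their LSEs."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


def ring_attention_sim(q: Vec, k_blocks: List[List[Vec]], v_blocks: List[List[Vec]]) -> Vec:
    """One rank's view: start from its own K/V block, then merge each block as it arrives."""
    out, lse = block_attention(q, k_blocks[0], v_blocks[0])
    for keys, values in zip(k_blocks[1:], v_blocks[1:]):
        out_b, lse_b = block_attention(q, keys, values)
        out, lse = merge_partials(out, lse, out_b, lse_b)
    return out
```

Because the merge is exact in real arithmetic, any difference between the SP and non-SP paths comes from floating-point rounding in the per-block outputs and merges, which is where the BF16 precision risk comes from.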
The appeal is low code intrusion. The integration point is narrow: replace the flash_attn_varlen_func call in the attention layer with ring_flash_attn_varlen_func, pass in the SP process group, and handle RoPE position offsets. The rest of the model — MLP layers, normalization, embeddings — stays untouched. Varlen support means we keep our existing data packing pipeline, which is important because we pack multiple sequences per batch for throughput.
Risks: RoPE position encoding offsets need to account for the global sequence position, not just the local chunk. The precision of the exchanged K/V blocks and of the partial-output merge under BF16 needs verification, since ring communication introduces extra rounding. And we need to confirm that backward gradients through the ring are numerically stable.
Route B: DeepSpeed-Ulysses (spiked, deferred)
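Ulysses takes a different approach: before attention it uses All-to-All communication to re-shard Q, K, and V from a sequence split to a head split (each GPU then holds the full sequence for a subset of heads), computes standard attention, then All-to-Alls the output back to a sequence split. This has better asymptotic scaling for very long sequences because each GPU’s communication volume shrinks as the SP degree grows, whereas in ring attention every rank still receives nearly all of the sequence’s K/V. The tradeoff is that the Ulysses SP degree is bounded by the number of attention heads.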
We spiked this and decided against it for the MVP. The FSDP integration is significantly more complex — Ulysses needs to restructure tensor layouts before and after every attention block, which conflicts with FSDP’s assumptions about tensor shapes during the forward pass. The communication pattern is also harder to overlap with compute. For our target of 8k-16k sequences on a single node, Ring-Attention’s bandwidth requirements are modest over NVLink, and the implementation simplicity wins.
If we later need to scale beyond 32k sequences or go cross-node, Ulysses becomes worth revisiting.
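A rough traffic model makes the tradeoff concrete. This sketch counts the tensor elements one rank sends during a single attention forward pass and ignores backward, compute overlap, and link topology; the formulas are our simplification, with q_dim = query heads × head_dim and kv_dim = KV heads × head_dim.

```python
def ring_send_elems(seq_len: int, kv_dim: int, sp: int) -> float:
    """Elements one rank sends per attention forward in ring attention.

    Each rank forwards a K chunk and a V chunk (seq_len / sp tokens each)
    to its neighbor on each of the sp - 1 ring steps.
    """
    return 2 * (seq_len / sp) * kv_dim * (sp - 1)


def ulysses_send_elems(seq_len: int, q_dim: int, kv_dim: int, sp: int) -> float:
    """Elements one rank sends per attention forward in Ulysses.

    Four All-to-Alls (Q, K, V, and the output O); in each, a rank keeps
    1/sp of its local (seq_len / sp)-token slice and sends the rest.
    """
    return (seq_len / sp) * (2 * q_dim + 2 * kv_dim) * (sp - 1) / sp
```

With Llama3-8B’s shapes (q_dim = 4096; kv_dim = 1024 from 8 KV heads of 128) at 16k tokens, this model has ring sending about 16.8M elements per rank at SP=2 versus 41.9M for Ulysses, while at SP=8 the order flips (29.4M versus 18.4M). GQA’s small K/V is a large part of why the ring path is cheap at the SP degrees we target.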
Fallback: llama3_varlen
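A less intrusive alternative to ring passing: the context-parallel scheme described for Llama 3, which all-gathers K/V across the SP group instead of passing blocks around a ring (ring-flash-attention ships it as llama3_flash_attn_varlen_func). Lower risk, slightly less flexible. We keep this as the llama3_varlen option in config, so users can fall back if the ring path has issues on their hardware.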
System Architecture
The design touches five layers. I’ll walk through each.
1. Process Groups: SPGroupManager
SP groups are carved out of the existing data-parallel ranks. If you have 8 GPUs with DP=8 and set SP=2, you get 4 SP groups of 2 GPUs each, and effective DP becomes 4.
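```python
import torch.distributed as dist

class SPGroupManager:
    """Builds SP groups within each DP shard."""

    def __init__(self, sp_size: int, group_scope: str = "local_node"):
        # Prefer same-node NVLink clustering
        # E.g., 8 GPUs, SP=2 -> groups [(0,1), (2,3), (4,5), (6,7)]
        self._sp_group = self._build_groups(sp_size, group_scope)

    def _build_groups(self, sp_size: int, group_scope: str) -> dist.ProcessGroup: ...  # not shown
    def get_sp_group(self) -> dist.ProcessGroup:
        return self._sp_group
    def get_sp_rank(self) -> int:
        return dist.get_rank(group=self._sp_group)
    def get_sp_world_size(self) -> int:
        return dist.get_world_size(group=self._sp_group)
```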
The default group_scope: local_node keeps SP groups within a single node, where NVLink bandwidth is high (about 600 GB/s per GPU on A100 and 900 GB/s on H100, total bidirectional). Cross-node SP over InfiniBand is technically possible, but the bandwidth penalty is steep and not worth it for our target sequence lengths.
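The rank arithmetic behind SPGroupManager is simple enough to show directly. This is a minimal sketch that assumes the usual torchrun layout (ranks 0-7 on node 0, 8-15 on node 1, and so on); build_sp_rank_groups, spans_nodes, and data_shard_index are hypothetical helpers, not slime’s code.

```python
from typing import List


def build_sp_rank_groups(world_size: int, sp_size: int, gpus_per_node: int,
                         group_scope: str = "local_node") -> List[List[int]]:
    """Partition global ranks into consecutive blocks of sp_size ranks.

    With ranks laid out node by node, a block never crosses a node boundary
    as long as sp_size divides gpus_per_node.
    """
    if sp_size < 1 or world_size % sp_size != 0:
        raise ValueError(f"sp_size={sp_size} must divide world_size={world_size}")
    if group_scope == "local_node" and gpus_per_node % sp_size != 0:
        raise ValueError(f"sp_size={sp_size} must divide gpus_per_node={gpus_per_node}")
    return [list(range(start, start + sp_size)) for start in range(0, world_size, sp_size)]


def spans_nodes(group: List[int], gpus_per_node: int) -> bool:
    """True if a group's ring traffic would have to leave the node."""
    return len({rank // gpus_per_node for rank in group}) > 1


def data_shard_index(rank: int, sp_size: int) -> int:
    """All ranks in one SP group read the same data shard (effective DP = world_size // sp_size)."""
    return rank // sp_size
```

In a real job, every rank would call torch.distributed.new_group(ranks) for each returned group in the same order (new_group must be entered by all processes, even non-members) and keep the group that contains its own rank.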
2. Data Layer: VarLenCollator
The existing collator packs multiple sequences into a single batch element with cu_seqlens marking boundaries. For SP, we extend this with two additions:
- Length-aware bucketing. Group sequences by similar length before packing to minimize padding waste when splitting across SP ranks.
- pack_to_multiple alignment. Pad the total packed length to a multiple of 128 (or of sp_size * 128, which also keeps each rank’s chunk 128-aligned), so each SP rank gets an evenly divisible chunk. This avoids ragged splits that would complicate the ring communication.
Each SP rank receives a contiguous slice of the packed sequence along with the cu_seqlens offsets adjusted to local coordinates.
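The padding and per-rank split can be sketched in a few lines. This assumes a single packed sequence with known segment lengths and a contiguous split across ranks; pad_packed_length and split_cu_seqlens are hypothetical names, and the sketch leaves out any extra global metadata the attention kernel may need (for example, to apply the causal mask across ranks).

```python
from typing import List


def pad_packed_length(total_len: int, sp_size: int, align: int = 128) -> int:
    """Round total_len up to a multiple of sp_size * align."""
    multiple = sp_size * align
    return (total_len + multiple - 1) // multiple * multiple


def split_cu_seqlens(seq_lens: List[int], sp_size: int, align: int = 128) -> List[List[int]]:
    """Per-rank cu_seqlens, in local coordinates, for contiguous slices of one pack.

    Padding is appended as one trailing pseudo-sequence so every token
    belongs to some segment.
    """
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    padded = pad_packed_length(cu[-1], sp_size, align)
    if padded > cu[-1]:
        cu.append(padded)
    chunk = padded // sp_size
    per_rank = []
    for rank in range(sp_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        local = [0]
        for start, end in zip(cu[:-1], cu[1:]):
            # Keep only the part of each segment that falls inside [lo, hi).
            if min(end, hi) > max(start, lo):
                local.append(min(end, hi) - lo)
        per_rank.append(local)
    return per_rank
```

For segments of 100, 200, and 50 tokens at SP=2, the pack is padded to 512 and the second sequence straddles the boundary: rank 0 gets [0, 100, 256] and rank 1 gets [0, 44, 94, 256].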
3. Attention Replacement: attn_patch.py
The core integration point. A function substitute_flash_attn_with_ring() walks the model’s attention modules and swaps the FlashAttention call for the ring variant:
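```python
import torch.distributed as dist
from torch import nn


def substitute_flash_attn_with_ring(
    model: nn.Module, sp_group: dist.ProcessGroup, heads_k_stride: int = 1
) -> None:
    """
    Replace flash_attn_varlen_func with ring_flash_attn_varlen_func
    in all attention layers.

    Maintains:
      - RoPE offsets (global position = local_pos + sp_rank * chunk_len)
      - cu_seqlens in local coordinates
      - BF16/FP16 compatibility
    """
    for module in model.modules():
        # Only attention modules expose a FlashAttention forward to patch.
        if hasattr(module, "_flash_attn_forward"):
            # _make_ring_wrapper is defined elsewhere in attn_patch.py (not shown);
            # it wraps the original call and routes it through the SP group.
            module._flash_attn_forward = _make_ring_wrapper(
                module._flash_attn_forward, sp_group, heads_k_stride
            )
```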
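The RoPE offset is the trickiest part. Each SP rank applies rotary embeddings using global position indices, not local ones. For a single unpacked sequence, we compute position_ids = local_position_ids + sp_rank * local_seq_len before the RoPE application. With packing, positions restart at zero at every sequence boundary, so the offset has to be derived from the global cu_seqlens rather than added to the whole chunk. Getting this wrong produces silently wrong attention patterns: the model trains but learns garbage.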
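With packed batches, RoPE positions restart at zero at each sequence boundary, so a single per-rank offset is not enough when a rank’s slice contains (parts of) several sequences. A minimal sketch, assuming the rank holds a contiguous slice of the pack and that the global cu_seqlens covers the padded length; local_position_ids is a hypothetical helper, not slime’s API.

```python
from typing import List


def local_position_ids(cu_seqlens: List[int], rank: int, chunk_len: int) -> List[int]:
    """RoPE position ids for one SP rank's contiguous slice of a packed batch.

    cu_seqlens are global boundaries over the padded pack (last entry must be
    >= the end of this slice); each token's position is its global index
    minus the start of the sequence it belongs to.
    """
    start = rank * chunk_len
    positions = []
    seg = 0
    for global_idx in range(start, start + chunk_len):
        # Advance to the segment containing this global token index.
        while global_idx >= cu_seqlens[seg + 1]:
            seg += 1
        positions.append(global_idx - cu_seqlens[seg])
    return positions
```

For two packed sequences of lengths 3 and 5 split over two ranks, rank 1 should see positions [1, 2, 3, 4]; the single-sequence offset would give [4, 5, 6, 7], which is exactly the kind of silent misalignment a dedicated RoPE test has to catch.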
4. Forward/Backward and Loss Aggregation
Most of the forward pass runs identically to the non-SP case. SP only affects attention blocks. For loss computation, there’s a subtlety:
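- Token-level losses (cross-entropy, KL): Each SP rank holds logits for its local token chunk. For standard next-token prediction this is fine, provided labels are shifted on the full packed sequence before it is split, so the last token of each chunk still has its target. Each rank computes the summed loss and the number of valid tokens for its chunk, and we AllReduce both to get the global per-token mean; averaging per-rank means would give too much weight to ranks with fewer valid tokens.
- Sequence-level aggregation: When we need full-sequence logits (e.g., for certain reward computations), we AllGather logits across the SP group. This is expensive for large vocabularies, so we do it on-demand rather than by default.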
```yaml
allgather_logits: on_demand  # only when loss function requires full sequence
```
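A token-level mean needs both the loss sum and the valid-token count reduced across the SP group, because ranks can hold different numbers of valid tokens (padding, masked prompt tokens). The sketch below simulates that reduction on plain numbers; sp_token_mean is a hypothetical name, and the per-rank tuples stand in for tensors.

```python
from typing import List, Tuple


def sp_token_mean(per_rank: List[Tuple[float, int]]) -> float:
    """Global per-token loss mean from each SP rank's (loss_sum, valid_token_count).

    Mirrors summing both quantities across the SP group and dividing once,
    rather than averaging each rank's local mean.
    """
    total_loss = sum(loss_sum for loss_sum, _ in per_rank)
    total_tokens = sum(count for _, count in per_rank)
    return total_loss / max(total_tokens, 1)
```

With (10.0, 10) on one rank and (2.0, 1) on the other, the token mean is 12/11, about 1.09, whereas averaging the local means gives 1.5. In real training both values would go into one tensor and be summed with torch.distributed.all_reduce(t, op=dist.ReduceOp.SUM, group=sp_group).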
5. Checkpointing
SP partitions activations during the forward pass but does not partition weights. Checkpointing uses the standard FSDP shard format. This means you can save a checkpoint with SP=2 and load it with SP=1 (or any other SP degree) without any conversion. This was a deliberate design choice — we didn’t want SP degree to be baked into the checkpoint format.
Config Design and CLI Interface
The full config block:
```yaml
sp:
  enable: true
  impl: ring                   # ring | llama3_varlen | ulysses
  size: 2                      # SP degree, must divide world_size
  group_scope: local_node      # local_node | global
  heads_k_stride: 1
  allgather_logits: on_demand  # always | on_demand | never
```
impl controls which attention backend to use. size is the SP degree. group_scope controls whether SP groups are constrained to a single node. heads_k_stride is passed through to the attention kernel for GQA models; in ring-flash-attention it sets how many K/V heads are communicated per step, so a smaller stride lowers peak memory at the cost of more collectives. allgather_logits controls when full-sequence logit gathering happens.
When sp.enable is false, the entire SP code path is a no-op. No process groups are created, no attention patching happens, no collator changes. This is important for debugging — you can always fall back to the exact same training setup as before.
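Checking the sp: section once at startup lets a bad config fail fast and keeps the disabled path a true no-op. A minimal sketch; SPConfig and validate_sp_config are hypothetical names mirroring the YAML keys, and the node-crossing check assumes consecutive ranks share a node.

```python
import warnings
from dataclasses import dataclass
from typing import Optional

VALID_IMPLS = {"ring", "llama3_varlen", "ulysses"}
VALID_SCOPES = {"local_node", "global"}
VALID_GATHER = {"always", "on_demand", "never"}


@dataclass
class SPConfig:
    """Mirror of the sp: config block (field names follow the YAML keys)."""
    enable: bool = False
    impl: str = "ring"
    size: int = 1
    group_scope: str = "local_node"
    heads_k_stride: int = 1
    allgather_logits: str = "on_demand"


def validate_sp_config(cfg: SPConfig, world_size: int, gpus_per_node: int) -> Optional[SPConfig]:
    """Return cfg if SP should be set up, or None for the no-op path."""
    if not cfg.enable:
        return None  # no process groups, no attention patching, no collator changes
    if cfg.impl not in VALID_IMPLS:
        raise ValueError(f"sp.impl must be one of {sorted(VALID_IMPLS)}, got {cfg.impl!r}")
    if cfg.group_scope not in VALID_SCOPES:
        raise ValueError(f"sp.group_scope must be one of {sorted(VALID_SCOPES)}")
    if cfg.allgather_logits not in VALID_GATHER:
        raise ValueError(f"sp.allgather_logits must be one of {sorted(VALID_GATHER)}")
    if cfg.size < 1 or world_size % cfg.size != 0:
        raise ValueError(f"sp.size={cfg.size} must be >= 1 and divide world_size={world_size}")
    if cfg.heads_k_stride < 1:
        raise ValueError("sp.heads_k_stride must be >= 1")
    # With consecutive ranks per node, groups cross nodes iff size does not divide gpus_per_node.
    crosses_nodes = gpus_per_node % cfg.size != 0
    if crosses_nodes and cfg.group_scope == "local_node":
        raise ValueError(f"sp.size={cfg.size} does not fit within {gpus_per_node} GPUs per node")
    if crosses_nodes:
        warnings.warn("SP groups span nodes; ring traffic will go over the inter-node network")
    return cfg
```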
Launch command:
```bash
torchrun --nproc_per_node=8 train_fsdp.py \
  --config slime/configs/fsdp_sp_example.yaml \
  --sp.enable true --sp.impl ring --sp.size 2
```
RL Coupling Considerations
SP interacts with RL training in a few non-obvious ways.
GAE computation. Generalized Advantage Estimation runs over the full trajectory. With SP, each rank holds advantages for its local token chunk. GAE is inherently sequential (it’s a backward scan over timesteps, A_t = δ_t + γλ·A_{t+1}), so we compute it locally on each rank’s shard and AllGather at episode boundaries to reconstruct full-trajectory advantages. The local scan is only exact if each shard is seeded with the value estimate and the advantage at the first token of the next shard. Consistent value estimates at shard boundaries, which we get because the value head sees the same hidden states via the ring attention, are necessary but not sufficient, since every advantage also depends on rewards later in the trajectory.
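The boundary carry is easiest to see in code. This is a minimal sketch with Python lists standing in for per-rank tensors; gae_scan and sharded_gae are hypothetical names, and in a real run the boundary value and advantage would move between ranks with point-to-point sends or a small gather rather than a Python loop.

```python
from typing import List, Tuple


def gae_scan(rewards: List[float], values: List[float], gamma: float, lam: float,
             next_value: float = 0.0, next_adv: float = 0.0) -> List[float]:
    """Backward GAE scan over one span, seeded with the value/advantage just after it.

    For the final span of an episode, next_value = next_adv = 0 (terminal state).
    """
    advs = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_adv = delta + gamma * lam * next_adv
        advs[t] = next_adv
        next_value = values[t]
    return advs


def sharded_gae(shards: List[Tuple[List[float], List[float]]],
                gamma: float, lam: float) -> List[float]:
    """Scan shards from last to first, passing each shard's first value/advantage backward."""
    out: List[List[float]] = []
    next_value, next_adv = 0.0, 0.0
    for rewards, values in reversed(shards):
        advs = gae_scan(rewards, values, gamma, lam, next_value, next_adv)
        out.append(advs)
        next_value, next_adv = values[0], advs[0]
    return [a for advs in reversed(out) for a in advs]
```

Seeding every shard with zeros instead would silently truncate each advantage at the shard boundary.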
KL divergence against the reference model. Computing exact KL requires the full-vocabulary distributions of both the policy and the reference model, and with SP, gathering full-sequence logits from both models is expensive. We sidestep this with Schulman’s sample-based KL approximation, which only needs the log-probabilities that the policy and the reference model assign to the sampled tokens. Each SP rank computes both for its own tokens, so no full-sequence gather is needed for either model.
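The k3 form of Schulman’s approximation still takes the reference model’s log-probability of each sampled token; what it avoids is full-vocabulary logits, which is why each SP rank can evaluate it on its local tokens. A minimal per-token sketch (k3_kl is a hypothetical name), assuming tokens were sampled from the policy:

```python
import math
from typing import List


def k3_kl(policy_logprobs: List[float], ref_logprobs: List[float]) -> List[float]:
    """Per-token k3 estimate of KL(policy || reference) from sampled-token log-probs.

    With log_ratio = log pi_ref(x) - log pi(x) and x sampled from pi,
    k3 = exp(log_ratio) - 1 - log_ratio. It is unbiased (E[exp(log_ratio)] = 1)
    and never negative, since e^u - 1 - u >= 0 for all u.
    """
    out = []
    for lp, lr in zip(policy_logprobs, ref_logprobs):
        log_ratio = lr - lp
        out.append(math.exp(log_ratio) - 1.0 - log_ratio)
    return out
```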
Rollout generation. SGLang stays in TP mode during rollout. SP is only active during the training forward and backward passes. This separation is clean — the rollout server doesn’t know or care about SP. Sequences are gathered into full form before being sent to the training loop’s collator, which then re-splits them for SP.
Test Plan and Risk Mitigation
Test Matrix
Unit tests:
- SP group construction: verify correct rank assignments for various SP sizes and node topologies
- Ring attention patch: compare outputs against vanilla FlashAttention on identical inputs, check numerical closeness (atol=1e-3 for BF16)
- VarLenCollator: verify cu_seqlens alignment and pack_to_multiple padding logic
End-to-end tests:
- 8xA100 or 8xH100 node, Llama3-8B, sequence lengths 8k and 16k
- Compare training loss curves: SP=1 (baseline) vs SP=2 vs SP=4
- Verify checkpoint save/load across different SP degrees
Benchmark targets:
| Config | Seq Len | Target Memory | Target Throughput |
|---|---|---|---|
| FSDP baseline | 8k | baseline | baseline |
| FSDP + SP=2 (ring) | 8k | >= 35% reduction | >= baseline |
| FSDP + SP=2 (ring) | 16k | trainable | within budget |
| FSDP + SP=4 (ring) | 16k | >= 50% reduction | trainable |
Stability criterion: 3+ hours with no OOM or NaN.
Key Risks
RoPE position misalignment. The most likely source of silent correctness bugs. Mitigation: a dedicated test that compares attention outputs with and without SP on a known input, checking that the position-dependent attention patterns match.
BF16 precision drift. Ring communication introduces extra cast/reduce operations. Mitigation: run a 500-step training comparison between SP=1 and SP=2, checking that loss curves diverge by less than 1%.
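The 1% divergence check can be scripted against logged losses. A minimal sketch, assuming both runs log per-step loss on the same data order; comparing trailing moving averages (and the 20-step window) is our assumption about how to read "within 1%", and the helper names are hypothetical.

```python
from typing import List, Sequence


def moving_average(xs: Sequence[float], window: int) -> List[float]:
    """Trailing moving average; the first window - 1 steps are dropped."""
    return [sum(xs[i - window + 1 : i + 1]) / window for i in range(window - 1, len(xs))]


def max_relative_gap(baseline: Sequence[float], candidate: Sequence[float],
                     window: int = 20) -> float:
    """Largest relative gap between the smoothed SP=1 and SP=2 loss curves."""
    if len(baseline) != len(candidate) or len(baseline) < window:
        raise ValueError("curves must have equal length of at least `window` steps")
    smooth_b = moving_average(baseline, window)
    smooth_c = moving_average(candidate, window)
    return max(abs(c - b) / abs(b) for b, c in zip(smooth_b, smooth_c))
```

A run would then pass if max_relative_gap(sp1_losses, sp2_losses) < 0.01 over the 500 steps.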
Cross-node bandwidth. If someone sets group_scope: global with SP=2 where ranks span nodes, the ring communication goes over InfiniBand instead of NVLink. This could be 10-20x slower. Mitigation: warn loudly in logs when cross-node SP is detected, default to local_node scope.
RL metric aggregation. Reward and advantage statistics need consistent reduction across SP and DP groups. Mitigation: assert that aggregated metrics match between SP=1 and SP=2 runs on the same data.
Timeline
The 7-day plan splits work between two people. Person A owns the distributed infrastructure (process groups, data layer, training loop integration). Person B owns the kernel and model layer (attention patching, RoPE offsets, numerical validation). Days 1-2 are parallel groundwork, days 3-4 are integration, day 5 is end-to-end testing, days 6-7 are benchmarking and documentation. The llama3_varlen fallback gets implemented as a stretch goal on day 6 if ring attention is stable.
This is a tightly scoped MVP. We’re deliberately not touching Ulysses, not changing the checkpoint format, and not adding SP-aware gradient checkpointing (though that’s a natural follow-up). The goal is to get Ring-Attention SP working reliably with FSDP on a single node, validate the memory savings, and ship it as an opt-in config flag. Everything else can come later.