Skip to content

[WAN 2.2] 2D Ring Attention with Custom Kernel#430

Open
eltsai wants to merge 1 commit into
mainfrom
2d_ring_custom_kernel
Open

[WAN 2.2] 2D Ring Attention with Custom Kernel#430
eltsai wants to merge 1 commit into
mainfrom
2d_ring_custom_kernel

Conversation

@eltsai

@eltsai eltsai commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Summary

In this PR, we are adding support for

  1. Ring attention with custom kernel (like Fix: Plumb vmem limit and bkv_compute_in to the custom kernel #416 we plumbed the vmem limits into the kernel)
  2. 2D ring attention (ring + ulysses) with custom kernel
  3. Bidirectional ring attention

(also formatting src/maxdiffusion/loaders/ltx2_lora_nnx_loader.py to remove formatting errors)

Results

Compared with previous best config

For dp2-cp4 on v7x-8, we see:

==================================================
  TIMING SUMMARY (2D Ring, U=2, R=2)
==================================================
  Load (checkpoint):     585.3s
  Compile:               144.3s
  Inference:             116.0s
  ────────────────────────────────────────
  Conditioning:            3.9s
    - VAE Encode:          0.0s
  Denoise Total:         109.8s
  VAE Decode:              2.2s
    - TPU Compute:         1.7s
    - Host Formatting:     0.5s
==================================================

which is a 11% speedup comparing best Ulysses config for the denoising step (9% e2e time):

==================================================
  TIMING SUMMARY (Ulysses)
==================================================
  Load (checkpoint):     141.2s
  Compile:               183.3s
  Inference:             126.9s
  ────────────────────────────────────────
  Conditioning:            2.2s
    - VAE Encode:          0.0s
  Denoise Total:         122.6s
  VAE Decode:              2.1s
    - TPU Compute:         1.7s
    - Host Formatting:     0.4s
==================================================

Best Configs

The 2D-ring (Ulysses×Ring) hybrid is optimal at cp4 and cp16, and ties pure Ulysses at cp8. Pure Ulysses never strictly wins.

  • At CP4, 2D-ring (U=2, R=2, BQ=9472) (2.74 s/step) beats the best pure Ulysses config (U=4, BQ=8448) (3.06 s/step) by ~10%.
  • At CP8, 2D-ring (U=4, R=2, BQ=9472) (1.595 sec/step) ties the best pure Ulysses config (U=8, BQ=8448) (1.590 sec/step).
  • At CP16 a 2D ring wins (and is the only option, pure Ulysses U=16 impossible because num_head=40). The best is (U=8, R=2, BQ=9472).

Compared with ring attention

  1. Using the custom kernel as base for ring attention gives a ~23% per-denoising-step speedup (cp=4: 4.85 → 3.74 s/step, BQ=4096).

  2. A tile-size (BQ) search found the optimal ring tile is BQ=9472 for R<=8, and BQ=4096 for R>=16. It's worth ~19% speedup (dp1-cp8 U=4: 1.98→1.60 s/step at BQ 4096→9472).

For more details, see the internal doc

Note on bi-directional (wrap-free) ring attention.

Because of TPU topology, when R_degree >= 8 we start to see slow downs due to communication. I did some PoC on bi-directional ring, and see improvements for some cases when R_degree=16. I think it would be more beneficial when we scale up, so I am including it in this PR (despite that it doesn't give us optimal results on CP4, CP8 and CP16).

@eltsai eltsai requested a review from entrpn as a code owner June 25, 2026 23:31
@github-actions

Copy link
Copy Markdown

@eltsai eltsai requested review from Perseus14, csgoogle and entrpn and removed request for entrpn June 25, 2026 23:31
@eltsai

eltsai commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator Author

Pasting more results for CP4, CP8 and CP16:
Screenshot 2026-06-25 at 4 46 33 PM
Screenshot 2026-06-25 at 4 46 49 PM
Screenshot 2026-06-25 at 4 46 58 PM

dp1 means that the CFG is set as False (the stats for dp1-cp16 is basically what we will see on dp2-cp16 on v7x-32 with CFG enabled).

@Perseus14 Perseus14 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the amazing work!

I have dropped a few comments for minor changes. PTAL!

Additionally in another PR, we should probably add a doc to the repo that talks about all the different attention kernels and where to use what.

Comment thread src/maxdiffusion/models/attention_flax.py Outdated
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py
Comment thread src/maxdiffusion/models/attention_flax.py Outdated
@github-actions

Copy link
Copy Markdown

🤖 Hi @csgoogle, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

csgoogle
csgoogle previously approved these changes Jun 26, 2026
@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @csgoogle, but I was unable to process your request. Please see the logs for more details.

from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_custom, ulysses_ring
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, tokamax_ring_custom, ulysses, ulysses_custom, ulysses_ring, ulysses_ring_custom, ulysses_ring_custom_bidir

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we comment the best flashblock sizes here itself, it would be easy for anypne to look into

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants