This commit is contained in:
Binyang Li
2026-06-16 01:14:14 +00:00
parent 29847e6179
commit 74dce951bf

View File

@@ -79,7 +79,6 @@ class MoECommunicatorConfig:
# Streams and overlap
comm_stream: Optional[torch.cuda.Stream] = None
enable_overlap: bool = False
overlap_granularity: str = "none" # "none", "tensor", "expert", "block"
```
The constructor can accept either a config object or keyword arguments:
@@ -200,23 +199,24 @@ class MoECommunicator:
) -> torch.Tensor:
...
def dispatch_start(...) -> DispatchRequest:
def dispatch_async(..., overlap_config: Optional[CommOverlapConfig] = None) -> DispatchRequest:
...
def combine_start(...) -> CombineRequest:
def combine_async(..., overlap_config: Optional[CommOverlapConfig] = None) -> CombineRequest:
...
def make_combine_overlap(
def create_overlap_config(
self,
handle: DispatchHandle,
op: str, # "dispatch" or "combine"
*,
granularity: str = "block",
) -> CombineOverlap:
handle: Optional[DispatchHandle] = None,
level: str = "op", # "op" or "block"
) -> CommOverlapConfig:
...
```
The blocking `dispatch` and `combine` methods should be the default path. The
`*_start` methods and `make_combine_overlap` are optional advanced APIs for
`*_async` methods and `create_overlap_config` are optional advanced APIs for
communication/computation overlap.
## High-level API
@@ -263,8 +263,61 @@ class DispatchOutput:
class DispatchHandle:
"""Opaque handle returned by dispatch and consumed by combine."""
@dataclass
class CommOverlapConfig:
op: str # "dispatch" or "combine"
level: str = "op" # "op" or "block"
stream: Optional[torch.cuda.Stream] = None
wait_event: Optional[torch.cuda.Event] = None
signal: Optional[torch.Tensor] = None
num_comm_sms: Optional[int] = None
block_m: Optional[int] = None
block_ready_value: Optional[int] = None
```
`create_overlap_config` creates optional overlap configuration for async
dispatch/combine calls.
```python
dispatch_overlap_config = moe_comm.create_overlap_config(op="dispatch")
combine_overlap_config = moe_comm.create_overlap_config(op="combine", handle=handle)
```
Operation-level overlap does not require `create_overlap_config`; `dispatch_async`
and `combine_async` can use their default stream/event behavior. Use
`create_overlap_config` when the caller wants explicit stream/event/SM settings
or block-level combine overlap.
For block-level MLP/combine overlap, the config includes the combine-side wait
protocol and the device signal that an overlap-capable MLP backend must publish:
```python
combine_overlap_config = moe_comm.create_overlap_config(
op="combine",
handle=handle,
level="block",
)
```
`op="dispatch", level="block"` is not part of the first version. Dispatch
overlap is operation-level only.
`CommOverlapConfig` fields:
| Field | Purpose |
|---|---|
| `op` | `"dispatch"` or `"combine"` |
| `level` | `"op"` or `"block"` |
| `stream` | Optional communication stream |
| `wait_event` | Optional event the communication op waits on before starting |
| `signal` | Device tensor written by MLP and waited on by combine for block overlap |
| `num_comm_sms` | Optional SM budget for communication |
| `block_m` | Rows per block for block overlap |
| `block_ready_value` | Signal value that marks one block as ready for combine |
`DispatchHandle` should store the metadata needed to reverse dispatch:
- source rank and source token index,
@@ -480,8 +533,13 @@ expert_output = mlp(dispatch_out.tokens, dispatch_out.num_tokens_per_expert)
output = moe_comm.combine(expert_output, handle)
```
For overlap, expose an optional split-phase API rather than adding many flags to
the default path.
For overlap, expose two optional APIs rather than adding many flags to the
default path:
| API | Purpose |
|---|---|
| `dispatch_async` / `combine_async` | Coarse-grained async launch and wait |
| `create_overlap_config(..., level="block")` | Fine-grained block notify between MLP down-GEMM and combine |
### Coarse-grained overlap
@@ -489,14 +547,26 @@ Coarse-grained overlap lets the caller launch communication on a communication
stream and wait later.
```python
dispatch_req = moe_comm.dispatch_start(input, topk_ids, weights, scales)
dispatch_overlap_config = moe_comm.create_overlap_config(op="dispatch")
dispatch_req = moe_comm.dispatch_async(
input,
topk_ids,
weights,
scales,
overlap_config=dispatch_overlap_config,
)
# Run unrelated work while dispatch metadata/payload communication is in flight.
dispatch_out, handle = dispatch_req.wait()
expert_output = mlp(dispatch_out.tokens, dispatch_out.num_tokens_per_expert)
combine_req = moe_comm.combine_start(expert_output, handle)
combine_overlap_config = moe_comm.create_overlap_config(op="combine", handle=handle)
combine_req = moe_comm.combine_async(
expert_output,
handle,
overlap_config=combine_overlap_config,
)
# Run unrelated work while combine is in flight.
@@ -508,42 +578,48 @@ object should own any stream event or receive hook required by the backend.
### Fine-grained MLP/combine overlap
Fine-grained overlap sends combine data as soon as the MLP produces a block or
expert segment. This requires a device-side notify/signal from the MLP backend
to the combine kernel.
Fine-grained overlap sends combine data as soon as the MLP produces a block.
This requires a device-side notify/signal from the MLP backend to the combine
kernel.
```python
overlap = moe_comm.make_combine_overlap(handle, granularity="block")
combine_overlap_config = moe_comm.create_overlap_config(
op="combine",
handle=handle,
level="block",
)
# User must adapt the MLP backend/adapter to consume this config and notify
# combine as blocks become ready.
config = combine_overlap_config
expert_output = mlp(
dispatch_out.tokens,
dispatch_out.num_tokens_per_expert,
notify=overlap.notify,
config=config,
)
combine_req = moe_comm.combine_start(
combine_req = moe_comm.combine_async(
expert_output,
handle,
overlap=overlap,
overlap_config=combine_overlap_config,
)
output = combine_req.wait()
```
The notify object is not routing metadata. It only tells combine when a region
of `expert_output` is ready to read. The routing/source mapping still comes
from `handle`.
The overlap config is not routing metadata. It only tells combine when a
region of `expert_output` is ready to read. The routing/source mapping still
comes from `handle`.
The MLP backend must follow these rules when using notify:
- write `expert_output` in the same row/slot order as `dispatch_out.tokens`,
- publish data before signaling readiness,
- signal at the agreed granularity, such as block, expert segment, or whole
tensor,
- use the signal value/protocol provided by `overlap`.
- signal at the block granularity defined by `overlap_config`,
- use the signal value/protocol provided by `overlap_config`.
If the MLP backend does not support notify, it can still use the blocking API or
coarse-grained `combine_start` after the full `expert_output` tensor is ready.
coarse-grained `combine_async` after the full `expert_output` tensor is ready.
This must be a joint contract between the dispatcher and the MLP runner. The
dispatcher can provide the signal buffer and combine-side wait protocol, but it
@@ -555,13 +631,13 @@ arguments after dispatch, passes combine-side arguments to the DeepEP dispatcher
and passes down-GEMM arguments to the MoE runner. Backend support is selective:
- DeepGEMM FP8 masked down-GEMM can return block metadata such as `block_m` and
`threshold` and signal combine readiness.
`block_ready_value` and signal combine readiness.
- FlashInfer CuteDSL can receive down-GEMM signal/start-event arguments.
- Some paths, such as BF16 masked DeepGEMM and generic Triton runners, do not
support this block overlap protocol.
Therefore, the API should expose overlap as an optional capability advertised by
the MLP backend, not as a guaranteed feature of every `combine_start` call.
the MLP backend, not as a guaranteed feature of every `combine_async` call.
## Internal metadata exchange