mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-06-30 11:27:30 +00:00
update
This commit is contained in:
@@ -79,7 +79,6 @@ class MoECommunicatorConfig:
|
||||
# Streams and overlap
|
||||
comm_stream: Optional[torch.cuda.Stream] = None
|
||||
enable_overlap: bool = False
|
||||
overlap_granularity: str = "none" # "none", "tensor", "expert", "block"
|
||||
```
|
||||
|
||||
The constructor can accept either a config object or keyword arguments:
|
||||
@@ -200,23 +199,24 @@ class MoECommunicator:
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
|
||||
def dispatch_start(...) -> DispatchRequest:
|
||||
def dispatch_async(..., overlap_config: Optional[CommOverlapConfig] = None) -> DispatchRequest:
|
||||
...
|
||||
|
||||
def combine_start(...) -> CombineRequest:
|
||||
def combine_async(..., overlap_config: Optional[CommOverlapConfig] = None) -> CombineRequest:
|
||||
...
|
||||
|
||||
def make_combine_overlap(
|
||||
def create_overlap_config(
|
||||
self,
|
||||
handle: DispatchHandle,
|
||||
op: str, # "dispatch" or "combine"
|
||||
*,
|
||||
granularity: str = "block",
|
||||
) -> CombineOverlap:
|
||||
handle: Optional[DispatchHandle] = None,
|
||||
level: str = "op", # "op" or "block"
|
||||
) -> CommOverlapConfig:
|
||||
...
|
||||
```
|
||||
|
||||
The blocking `dispatch` and `combine` methods should be the default path. The
|
||||
`*_start` methods and `make_combine_overlap` are optional advanced APIs for
|
||||
`*_async` methods and `create_overlap_config` are optional advanced APIs for
|
||||
communication/computation overlap.
|
||||
|
||||
## High-level API
|
||||
@@ -263,8 +263,61 @@ class DispatchOutput:
|
||||
|
||||
class DispatchHandle:
|
||||
"""Opaque handle returned by dispatch and consumed by combine."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommOverlapConfig:
|
||||
op: str # "dispatch" or "combine"
|
||||
level: str = "op" # "op" or "block"
|
||||
stream: Optional[torch.cuda.Stream] = None
|
||||
wait_event: Optional[torch.cuda.Event] = None
|
||||
signal: Optional[torch.Tensor] = None
|
||||
num_comm_sms: Optional[int] = None
|
||||
block_m: Optional[int] = None
|
||||
block_ready_value: Optional[int] = None
|
||||
|
||||
```
|
||||
|
||||
`create_overlap_config` creates optional overlap configuration for async
|
||||
dispatch/combine calls.
|
||||
|
||||
```python
|
||||
dispatch_overlap_config = moe_comm.create_overlap_config(op="dispatch")
|
||||
combine_overlap_config = moe_comm.create_overlap_config(op="combine", handle=handle)
|
||||
```
|
||||
|
||||
Operation-level overlap does not require `create_overlap_config`; `dispatch_async`
|
||||
and `combine_async` can use their default stream/event behavior. Use
|
||||
`create_overlap_config` when the caller wants explicit stream/event/SM settings
|
||||
or block-level combine overlap.
|
||||
|
||||
For block-level MLP/combine overlap, the config includes the combine-side wait
|
||||
protocol and the device signal that an overlap-capable MLP backend must publish:
|
||||
|
||||
```python
|
||||
combine_overlap_config = moe_comm.create_overlap_config(
|
||||
op="combine",
|
||||
handle=handle,
|
||||
level="block",
|
||||
)
|
||||
```
|
||||
|
||||
`op="dispatch", level="block"` is not part of the first version. Dispatch
|
||||
overlap is operation-level only.
|
||||
|
||||
`CommOverlapConfig` fields:
|
||||
|
||||
| Field | Purpose |
|
||||
|---|---|
|
||||
| `op` | `"dispatch"` or `"combine"` |
|
||||
| `level` | `"op"` or `"block"` |
|
||||
| `stream` | Optional communication stream |
|
||||
| `wait_event` | Optional event the communication op waits on before starting |
|
||||
| `signal` | Device tensor written by MLP and waited on by combine for block overlap |
|
||||
| `num_comm_sms` | Optional SM budget for communication |
|
||||
| `block_m` | Rows per block for block overlap |
|
||||
| `block_ready_value` | Signal value that marks one block as ready for combine |
|
||||
|
||||
`DispatchHandle` should store the metadata needed to reverse dispatch:
|
||||
|
||||
- source rank and source token index,
|
||||
@@ -480,8 +533,13 @@ expert_output = mlp(dispatch_out.tokens, dispatch_out.num_tokens_per_expert)
|
||||
output = moe_comm.combine(expert_output, handle)
|
||||
```
|
||||
|
||||
For overlap, expose an optional split-phase API rather than adding many flags to
|
||||
the default path.
|
||||
For overlap, expose two optional APIs rather than adding many flags to the
|
||||
default path:
|
||||
|
||||
| API | Purpose |
|
||||
|---|---|
|
||||
| `dispatch_async` / `combine_async` | Coarse-grained async launch and wait |
|
||||
| `create_overlap_config(..., level="block")` | Fine-grained block notify between MLP down-GEMM and combine |
|
||||
|
||||
### Coarse-grained overlap
|
||||
|
||||
@@ -489,14 +547,26 @@ Coarse-grained overlap lets the caller launch communication on a communication
|
||||
stream and wait later.
|
||||
|
||||
```python
|
||||
dispatch_req = moe_comm.dispatch_start(input, topk_ids, weights, scales)
|
||||
dispatch_overlap_config = moe_comm.create_overlap_config(op="dispatch")
|
||||
dispatch_req = moe_comm.dispatch_async(
|
||||
input,
|
||||
topk_ids,
|
||||
weights,
|
||||
scales,
|
||||
overlap_config=dispatch_overlap_config,
|
||||
)
|
||||
|
||||
# Run unrelated work while dispatch metadata/payload communication is in flight.
|
||||
|
||||
dispatch_out, handle = dispatch_req.wait()
|
||||
expert_output = mlp(dispatch_out.tokens, dispatch_out.num_tokens_per_expert)
|
||||
|
||||
combine_req = moe_comm.combine_start(expert_output, handle)
|
||||
combine_overlap_config = moe_comm.create_overlap_config(op="combine", handle=handle)
|
||||
combine_req = moe_comm.combine_async(
|
||||
expert_output,
|
||||
handle,
|
||||
overlap_config=combine_overlap_config,
|
||||
)
|
||||
|
||||
# Run unrelated work while combine is in flight.
|
||||
|
||||
@@ -508,42 +578,48 @@ object should own any stream event or receive hook required by the backend.
|
||||
|
||||
### Fine-grained MLP/combine overlap
|
||||
|
||||
Fine-grained overlap sends combine data as soon as the MLP produces a block or
|
||||
expert segment. This requires a device-side notify/signal from the MLP backend
|
||||
to the combine kernel.
|
||||
Fine-grained overlap sends combine data as soon as the MLP produces a block.
|
||||
This requires a device-side notify/signal from the MLP backend to the combine
|
||||
kernel.
|
||||
|
||||
```python
|
||||
overlap = moe_comm.make_combine_overlap(handle, granularity="block")
|
||||
combine_overlap_config = moe_comm.create_overlap_config(
|
||||
op="combine",
|
||||
handle=handle,
|
||||
level="block",
|
||||
)
|
||||
|
||||
# User must adapt the MLP backend/adapter to consume this config and notify
|
||||
# combine as blocks become ready.
|
||||
config = combine_overlap_config
|
||||
expert_output = mlp(
|
||||
dispatch_out.tokens,
|
||||
dispatch_out.num_tokens_per_expert,
|
||||
notify=overlap.notify,
|
||||
config=config,
|
||||
)
|
||||
|
||||
combine_req = moe_comm.combine_start(
|
||||
combine_req = moe_comm.combine_async(
|
||||
expert_output,
|
||||
handle,
|
||||
overlap=overlap,
|
||||
overlap_config=combine_overlap_config,
|
||||
)
|
||||
|
||||
output = combine_req.wait()
|
||||
```
|
||||
|
||||
The notify object is not routing metadata. It only tells combine when a region
|
||||
of `expert_output` is ready to read. The routing/source mapping still comes
|
||||
from `handle`.
|
||||
The overlap config is not routing metadata. It only tells combine when a
|
||||
region of `expert_output` is ready to read. The routing/source mapping still
|
||||
comes from `handle`.
|
||||
|
||||
The MLP backend must follow these rules when using notify:
|
||||
|
||||
- write `expert_output` in the same row/slot order as `dispatch_out.tokens`,
|
||||
- publish data before signaling readiness,
|
||||
- signal at the agreed granularity, such as block, expert segment, or whole
|
||||
tensor,
|
||||
- use the signal value/protocol provided by `overlap`.
|
||||
- signal at the block granularity defined by `overlap_config`,
|
||||
- use the signal value/protocol provided by `overlap_config`.
|
||||
|
||||
If the MLP backend does not support notify, it can still use the blocking API or
|
||||
coarse-grained `combine_start` after the full `expert_output` tensor is ready.
|
||||
coarse-grained `combine_async` after the full `expert_output` tensor is ready.
|
||||
|
||||
This must be a joint contract between the dispatcher and the MLP runner. The
|
||||
dispatcher can provide the signal buffer and combine-side wait protocol, but it
|
||||
@@ -555,13 +631,13 @@ arguments after dispatch, passes combine-side arguments to the DeepEP dispatcher
|
||||
and passes down-GEMM arguments to the MoE runner. Backend support is selective:
|
||||
|
||||
- DeepGEMM FP8 masked down-GEMM can return block metadata such as `block_m` and
|
||||
`threshold` and signal combine readiness.
|
||||
`block_ready_value` and signal combine readiness.
|
||||
- FlashInfer CuteDSL can receive down-GEMM signal/start-event arguments.
|
||||
- Some paths, such as BF16 masked DeepGEMM and generic Triton runners, do not
|
||||
support this block overlap protocol.
|
||||
|
||||
Therefore, the API should expose overlap as an optional capability advertised by
|
||||
the MLP backend, not as a guaranteed feature of every `combine_start` call.
|
||||
the MLP backend, not as a guaranteed feature of every `combine_async` call.
|
||||
|
||||
## Internal metadata exchange
|
||||
|
||||
|
||||
Reference in New Issue
Block a user