ext/ep: fix SWITCH_* macros and add missing standard headers

- Wrap SWITCH_* macros in launch.cuh in do { ... } while(false) so the
  trailing while(false) terminates the macro instead of dangling after
  the closing brace of the switch.
- Add #include <type_traits> to utils.cuh for std::remove_reference used
  in UNROLLED_WARP_COPY.
- Add #include <limits> to intranode_kernel.cu and internode.cu for
  std::numeric_limits.

Addresses Copilot review comments on PR #796.
This commit is contained in:
Qinghua Zhou
2026-05-06 03:18:39 +00:00
parent b2880652ce
commit c641487c55
4 changed files with 45 additions and 29 deletions

View File

@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <limits>
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"

View File

@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <limits>
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"

View File

@@ -19,45 +19,55 @@
#endif
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
do { \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} \
} while (false)
#define SWITCH_RDMA_RANKS(case_macro) \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
do { \
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
case 2: case_macro(2); \
case 3: case_macro(3); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 16: case_macro(16); \
case 18: case_macro(18); \
case 20: case_macro(20); \
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
} \
} while (false)
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
switch (num_ranks) { \
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
do { \
switch (num_ranks) { \
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
} \
} while (false)
#define SWITCH_TYPES(case_macro) \
switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \
do { \
switch (type) { \
case CUDA_R_16BF: case_macro(nv_bfloat16); \
case CUDA_R_32F: case_macro(float); \
default: EP_HOST_ASSERT(false && "Unsupported type"); \
} \
} while (false)
#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2560: case_macro(2560); \
case 4096: case_macro(4096); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
do { \
switch (hidden) { \
case 2560: case_macro(2560); \
case 4096: case_macro(4096); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} \
} while (false)

View File

@@ -2,6 +2,8 @@
// Licensed under the MIT License.
#pragma once
#include <type_traits>
#include "exception.cuh"
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \