mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
ext/ep: fix SWITCH_* macros and add missing standard headers
- Wrap SWITCH_* macros in launch.cuh in do { ... } while(false) so the
trailing while(false) terminates the macro instead of dangling after
the closing brace of the switch.
- Add #include <type_traits> to utils.cuh for std::remove_reference used
in UNROLLED_WARP_COPY.
- Add #include <limits> to intranode_kernel.cu and internode.cu for
std::numeric_limits.
Addresses Copilot review comments on PR #796.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
#include <limits>
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "buffer.cuh"
|
||||
#include "exception.cuh"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
#include <limits>
|
||||
|
||||
#include "configs.cuh"
|
||||
#include "buffer.cuh"
|
||||
#include "exception.cuh"
|
||||
|
||||
@@ -19,45 +19,55 @@
|
||||
#endif
|
||||
|
||||
#define SWITCH_RANKS(case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(2); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
|
||||
do { \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(2); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RDMA_RANKS(case_macro) \
|
||||
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
|
||||
case 2: case_macro(2); \
|
||||
case 3: case_macro(3); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
case 16: case_macro(16); \
|
||||
case 18: case_macro(18); \
|
||||
case 20: case_macro(20); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
|
||||
do { \
|
||||
switch (num_ranks / NUM_MAX_NVL_PEERS) { \
|
||||
case 2: case_macro(2); \
|
||||
case 3: case_macro(3); \
|
||||
case 4: case_macro(4); \
|
||||
case 8: case_macro(8); \
|
||||
case 16: case_macro(16); \
|
||||
case 18: case_macro(18); \
|
||||
case 20: case_macro(20); \
|
||||
default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(dtype, 2); \
|
||||
case 4: case_macro(dtype, 4); \
|
||||
case 8: case_macro(dtype, 8); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
|
||||
do { \
|
||||
switch (num_ranks) { \
|
||||
case 2: case_macro(dtype, 2); \
|
||||
case 4: case_macro(dtype, 4); \
|
||||
case 8: case_macro(dtype, 8); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_TYPES(case_macro) \
|
||||
switch (type) { \
|
||||
case CUDA_R_16BF: case_macro(nv_bfloat16); \
|
||||
case CUDA_R_32F: case_macro(float); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported type"); \
|
||||
do { \
|
||||
switch (type) { \
|
||||
case CUDA_R_16BF: case_macro(nv_bfloat16); \
|
||||
case CUDA_R_32F: case_macro(float); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported type"); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define SWITCH_HIDDEN(case_macro) \
|
||||
switch (hidden) { \
|
||||
case 2560: case_macro(2560); \
|
||||
case 4096: case_macro(4096); \
|
||||
case 5120: case_macro(5120); \
|
||||
case 7168: case_macro(7168); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
|
||||
do { \
|
||||
switch (hidden) { \
|
||||
case 2560: case_macro(2560); \
|
||||
case 4096: case_macro(4096); \
|
||||
case 5120: case_macro(5120); \
|
||||
case 7168: case_macro(7168); \
|
||||
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "exception.cuh"
|
||||
|
||||
#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
|
||||
|
||||
Reference in New Issue
Block a user