diff --git a/src/ext/ep/kernels/internode.cu b/src/ext/ep/kernels/internode.cu index 167ab139..98d29809 100644 --- a/src/ext/ep/kernels/internode.cu +++ b/src/ext/ep/kernels/internode.cu @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "configs.cuh" #include "buffer.cuh" #include "exception.cuh" diff --git a/src/ext/ep/kernels/intranode_kernel.cu b/src/ext/ep/kernels/intranode_kernel.cu index 56fa3899..f6af5c66 100644 --- a/src/ext/ep/kernels/intranode_kernel.cu +++ b/src/ext/ep/kernels/intranode_kernel.cu @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +#include + #include "configs.cuh" #include "buffer.cuh" #include "exception.cuh" diff --git a/src/ext/ep/kernels/launch.cuh b/src/ext/ep/kernels/launch.cuh index 79c9347a..94f9eb72 100644 --- a/src/ext/ep/kernels/launch.cuh +++ b/src/ext/ep/kernels/launch.cuh @@ -19,45 +19,55 @@ #endif #define SWITCH_RANKS(case_macro) \ - switch (num_ranks) { \ - case 2: case_macro(2); \ - case 4: case_macro(4); \ - case 8: case_macro(8); \ - default: EP_HOST_ASSERT(false and "Unsupported ranks"); \ + do { \ + switch (num_ranks) { \ + case 2: case_macro(2); \ + case 4: case_macro(4); \ + case 8: case_macro(8); \ + default: EP_HOST_ASSERT(false and "Unsupported ranks"); \ + } \ } while (false) #define SWITCH_RDMA_RANKS(case_macro) \ - switch (num_ranks / NUM_MAX_NVL_PEERS) { \ - case 2: case_macro(2); \ - case 3: case_macro(3); \ - case 4: case_macro(4); \ - case 8: case_macro(8); \ - case 16: case_macro(16); \ - case 18: case_macro(18); \ - case 20: case_macro(20); \ - default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ + do { \ + switch (num_ranks / NUM_MAX_NVL_PEERS) { \ + case 2: case_macro(2); \ + case 3: case_macro(3); \ + case 4: case_macro(4); \ + case 8: case_macro(8); \ + case 16: case_macro(16); \ + case 18: case_macro(18); \ + case 20: case_macro(20); \ + default: EP_HOST_ASSERT(false and "Unsupported RDMA ranks"); \ + } \ } while (false) #define SWITCH_RANKS_WITH_DTYPE(dtype, case_macro) \ - switch (num_ranks) { \ - case 2: case_macro(dtype, 2); \ - case 4: case_macro(dtype, 4); \ - case 8: case_macro(dtype, 8); \ - default: EP_HOST_ASSERT(false && "Unsupported ranks"); \ + do { \ + switch (num_ranks) { \ + case 2: case_macro(dtype, 2); \ + case 4: case_macro(dtype, 4); \ + case 8: case_macro(dtype, 8); \ + default: EP_HOST_ASSERT(false && "Unsupported ranks"); \ + } \ } while (false) #define SWITCH_TYPES(case_macro) \ - switch (type) { \ - case CUDA_R_16BF: case_macro(nv_bfloat16); \ - case CUDA_R_32F: case_macro(float); \ - default: EP_HOST_ASSERT(false && "Unsupported type"); \ + do { \ + switch (type) { \ + case CUDA_R_16BF: case_macro(nv_bfloat16); \ + case CUDA_R_32F: case_macro(float); \ + default: EP_HOST_ASSERT(false && "Unsupported type"); \ + } \ } while (false) #define SWITCH_HIDDEN(case_macro) \ - switch (hidden) { \ - case 2560: case_macro(2560); \ - case 4096: case_macro(4096); \ - case 5120: case_macro(5120); \ - case 7168: case_macro(7168); \ - default: EP_HOST_ASSERT(false && "Unsupported hidden"); \ + do { \ + switch (hidden) { \ + case 2560: case_macro(2560); \ + case 4096: case_macro(4096); \ + case 5120: case_macro(5120); \ + case 7168: case_macro(7168); \ + default: EP_HOST_ASSERT(false && "Unsupported hidden"); \ + } \ } while (false) diff --git a/src/ext/ep/kernels/utils.cuh b/src/ext/ep/kernels/utils.cuh index c3cf66e2..70ca21a4 100644 --- a/src/ext/ep/kernels/utils.cuh +++ b/src/ext/ep/kernels/utils.cuh @@ -2,6 +2,8 @@ // Licensed under the MIT License. #pragma once +#include + #include "exception.cuh" #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \