diff --git a/src/ext/ep/buffer.cc b/src/ext/ep/buffer.cc index 085404bc..51cd40bb 100644 --- a/src/ext/ep/buffer.cc +++ b/src/ext/ep/buffer.cc @@ -1707,7 +1707,8 @@ std::tuple, std::optional host_ts(kMaxBlocks * 4); + constexpr int kCombSlots = 8; + std::vector host_ts(kMaxBlocks * kCombSlots); cudaMemcpy(host_ts.data(), static_cast(workspace) + 196608, host_ts.size() * sizeof(uint64_t), @@ -1726,8 +1727,8 @@ std::tuple, std::optional, std::optional( reinterpret_cast(atomic_clean_flag) + 196608); #endif @@ -632,7 +634,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; #ifdef MSCCLPP_EP_LL_PROFILE - if (thread_id == 0) prof_buf[sm_id * 4 + 0] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 0] = clock64(); #endif if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { @@ -748,7 +750,7 @@ __global__ __launch_bounds__(kNumWarpGroups* kNumWarpsPerGroup * 32, 1) void com #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 1] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 1] = clock64(); #endif LOW_LATENCY_COMBINE_RECV: @@ -762,9 +764,12 @@ LOW_LATENCY_COMBINE_RECV: } #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 2] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 2] = clock64(); #endif cg::this_grid().sync(); +#ifdef MSCCLPP_EP_LL_PROFILE + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 3] = clock64(); +#endif EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads); EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); @@ -803,7 +808,7 @@ LOW_LATENCY_COMBINE_RECV: } #ifdef MSCCLPP_EP_LL_PROFILE __syncthreads(); - if (thread_id == 0) prof_buf[sm_id * 4 + 3] = clock64(); + if (thread_id == 0) prof_buf[sm_id * kCombProfSlots + 4] = clock64(); #endif }