diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 2918cd33bc..f6189c7495 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -10,6 +10,26 @@ #include #include +#if !defined(CK_TILE_HAS_ROW_NEWBCAST) +// row_newbcast (DPP modifier 0x157) support by architecture: +// - Not supported: gfx908 (MI100) and older +// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures + +#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__) +#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__) +// Explicitly disable for known unsupported architectures +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#else +// Assume support for gfx90a and newer (including all gfx94x and RDNA) +// This is safer as new architectures typically maintain backward compatibility +#define CK_TILE_HAS_ROW_NEWBCAST 1 +#endif +#else +// Conservative default for non-AMD or host compilation +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#endif +#endif + namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ @@ -380,18 +400,7 @@ struct MoeSortingKernel row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx =__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), 0x157, @@ -401,6 +410,17 @@ struct MoeSortingKernel data_t yyy = (__lane_id() / 8) % 2 == 0 ? 
0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // every lane reads the value held by lane 7 of its own 16-lane row (wave_size == 8 case) + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if ((__lane_id() / 8) % 2 != 0) { // same lanes that receive xxx (not 0) on the DPP path + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } +#endif } @@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx = __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), @@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) data_t yyy = (__lane_id() / 8) % 2 == 0 ? 
0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // every lane reads the value held by lane 7 of its own 16-lane row (wave_size == 8 case) + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = + __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if((__lane_id() / 8) % 2 != 0) + { // same lanes that receive xxx (not 0) on the DPP path + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } +#endif } if constexpr(wave_size > 8)