Ud fix moe sorting gfx908 (#2720)

* Add a ds_bpermute fallback for the row_newbcast:7 DPP instruction on gfx908 and older

* Better macro for selecting ROW_NEWBCAST

* clang-format the update

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Michael Mcminn
2025-11-03 10:31:31 -05:00
committed by GitHub
parent d405641f06
commit afe1ff618d

View File

@@ -10,6 +10,26 @@
#include <string>
#include <type_traits>
#if !defined(CK_TILE_HAS_ROW_NEWBCAST)
// CK_TILE_HAS_ROW_NEWBCAST selects how "broadcast lane 7 of each 16-lane row" is implemented:
//   1 -> use the DPP control 0x157 (row_newbcast:7) directly
//   0 -> use the ds_bpermute based fallback
// The macro may be pre-defined by the build to override the detection below.
//
// row_newbcast (DPP control 0x150-0x15F) support by architecture:
// - Not supported: every GFX8 / GFX9 target that predates gfx90a
//                  (gfx80x/gfx810, gfx900, gfx902, gfx904, gfx906, gfx908 (MI100), gfx909, gfx90c)
// - Supported:     gfx90a (MI200), gfx94x (MI300)
// - NOTE(review):  RDNA (gfx10+) is expected to accept the same encoding under the name
//                  row_share -- confirm against the LLVM AMDGPU DPP modifier documentation.
#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx801__) || defined(__gfx802__) || defined(__gfx803__) || defined(__gfx805__) || \
    defined(__gfx810__) || defined(__gfx900__) || defined(__gfx902__) || defined(__gfx904__) || \
    defined(__gfx906__) || defined(__gfx908__) || defined(__gfx909__) || defined(__gfx90c__)
// Explicitly disable for every known target older than gfx90a; the fallback is used instead.
#define CK_TILE_HAS_ROW_NEWBCAST 0
#else
// Assume support for gfx90a and newer (including all gfx94x and RDNA).
// This is an optimistic default: a new target without this DPP control must be added to the
// list above (or the macro pre-defined to 0) to route it to the fallback.
#define CK_TILE_HAS_ROW_NEWBCAST 1
#endif
#else
// Conservative default for non-AMD or host compilation
#define CK_TILE_HAS_ROW_NEWBCAST 0
#endif
#endif
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
@@ -380,18 +400,7 @@ struct MoeSortingKernel
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
#if CK_TILE_HAS_ROW_NEWBCAST
data_t xxx =__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
0x157,
@@ -401,6 +410,17 @@ struct MoeSortingKernel
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#else
// Portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute.
// Used in the wave_size == 8 step: each lane fetches the value held by lane 7 of its own
// 16-lane group.
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
}
#endif
}
@@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
row_mask,
bank_mask,
bound_ctrl))); // row_shr:8
#if 0
constexpr int bank_mask_0_7 = 0b1100;
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
__builtin_amdgcn_update_dpp(0, /* old value */
__builtin_bit_cast(int, thread_data),
0x157,
row_mask,
bank_mask_0_7,
bound_ctrl))// row_newbcast:7
);
#else
#if CK_TILE_HAS_ROW_NEWBCAST
data_t xxx =
__builtin_bit_cast(data_t,
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
@@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
thread_data = thread_data - yyy;
#else
// Portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute.
// Used in the wave_size == 8 step: each lane fetches the value held by lane 7 of its own
// 16-lane group.
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
int bcast7 =
__builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
if((__lane_id() / 8) % 2 != 0)
{ // Note: != 0, not == 0
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
}
#endif
}
if constexpr(wave_size > 8)