mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Ud fix moe sorting gfx908 (#2720)
* Add a ds_bpermute fallback for gfx908 and older, which lack the row_newbcast:7 DPP instruction
* Better macro for selecting ROW_NEWBCAST
* clang-format the update

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -10,6 +10,26 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#if !defined(CK_TILE_HAS_ROW_NEWBCAST)
|
||||
// CK_TILE_HAS_ROW_NEWBCAST: 1 when the DPP row_newbcast:7 path (DPP control 0x157) may be
// used, 0 when the ds_bpermute-based fallback must be compiled instead.
// The outer guard lets a build predefine the macro to override the detection below.
// Intended row_newbcast support by architecture:
|
||||
// - Not supported: gfx908 (MI100) and older
// NOTE(review): the check below only names gfx900, gfx906 and gfx908; any other
// pre-gfx90a target would take the "supported" branch - confirm none is a build target.
|
||||
// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures
// NOTE(review): the RDNA claim is not verifiable from this file (0x157 may be named
// row_share:7 there) - confirm against the ISA documentation.
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
|
||||
// Device pass for a target known to lack row_newbcast: explicitly disable it
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 0
|
||||
#else
|
||||
// Any other AMD device target is assumed to be gfx90a or newer (including gfx94x and RDNA)
|
||||
// Opt-out list rather than opt-in: targets not named above get the DPP path by default
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 1
|
||||
#endif
|
||||
#else
|
||||
// Conservative default for non-AMD or host compilation: select the fallback path
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
@@ -380,18 +400,7 @@ struct MoeSortingKernel
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
#if CK_TILE_HAS_ROW_NEWBCAST
|
||||
data_t xxx =__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
@@ -401,6 +410,17 @@ struct MoeSortingKernel
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#else
|
||||
// portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
|
||||
// For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
|
||||
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
|
||||
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
|
||||
int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
|
||||
|
||||
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
|
||||
if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0
|
||||
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
#if CK_TILE_HAS_ROW_NEWBCAST
|
||||
data_t xxx =
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
@@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#else
|
||||
// portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
|
||||
// For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
|
||||
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
|
||||
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
|
||||
int bcast7 =
|
||||
__builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
|
||||
|
||||
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
|
||||
if((__lane_id() / 8) % 2 != 0)
|
||||
{ // Note: != 0, not == 0
|
||||
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(wave_size > 8)
|
||||
|
||||
Reference in New Issue
Block a user