Merge commit 'dcd33a6ecc30e18cc8491ed03926ab5ac8b6f1c3' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-24 06:15:34 +00:00
parent a55a7e37ec
commit 167e5ab3b5
12 changed files with 121 additions and 142 deletions

View File

@@ -2570,6 +2570,60 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the
// memory to the SGPR registers.
__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}
__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v)
{
return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
}
__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}
__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}
template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
__device__ inline auto amd_wave_read_first_lane(const Object& obj)
{
constexpr size_t ObjectSize = sizeof(Object);
constexpr size_t SGPR_size = 4;
constexpr size_t NumFull = ObjectSize / SGPR_size;
constexpr size_t Tail = ObjectSize % SGPR_size;
const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
alignas(Object) unsigned char dst[ObjectSize];
static_for<0, NumFull, 1>{}([&](auto Ic) {
constexpr size_t offset = Ic * SGPR_size;
uint32_t read_src;
__builtin_memcpy(&read_src, src + offset, SGPR_size);
read_src = __builtin_amdgcn_readfirstlane(read_src);
__builtin_memcpy(dst + offset, &read_src, SGPR_size);
});
if constexpr(Tail != 0)
{
constexpr size_t offset = NumFull * SGPR_size;
uint32_t tail_loc = 0;
__builtin_memcpy(&tail_loc, src + offset, Tail);
tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
__builtin_memcpy(dst + offset, &tail_loc, Tail);
}
Object out;
__builtin_memcpy(&out, dst, ObjectSize);
return out;
}
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,

View File

@@ -158,7 +158,4 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}
template <typename Tile>
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
} // namespace ck_tile

View File

@@ -481,13 +481,10 @@ struct CShuffleEpilogue
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if scales are provided
// Build windows only if non-scalar scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, dram_tile_distribution);
}
else
@@ -498,9 +495,6 @@ struct CShuffleEpilogue
auto scale_n_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, dram_tile_distribution);
}
else
@@ -515,8 +509,8 @@ struct CShuffleEpilogue
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If scales provided, load them with identical distribution
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
// If non-scalar scales provided, load them with identical distribution
if constexpr(has_scales && !has_scalar_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
@@ -535,7 +529,7 @@ struct CShuffleEpilogue
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales)
else if constexpr(has_scales && !has_scalar_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
@@ -636,9 +630,6 @@ struct CShuffleEpilogue
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
@@ -653,9 +644,6 @@ struct CShuffleEpilogue
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else

View File

@@ -132,6 +132,10 @@ struct GemmKernelMultiABD
static constexpr index_t NumBTensor = BsDataType::size();
static constexpr index_t NumDTensor = DsDataType::size();
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>;
CK_TILE_HOST static auto GetName() -> const std::string
{
return UniversalGemmKernel::GetName();
@@ -181,6 +185,14 @@ struct GemmKernelMultiABD
{
return false;
}
// Currently MultiABD kernel doesn't support F8 data type
if(ck_tile::get_device_name() == "gfx950" &&
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
std::is_same<ck_tile::fp8_t, DDataType>::value))
{
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}

View File

@@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);

View File

@@ -100,7 +100,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
@@ -118,7 +118,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;