Fix the gfx950 numerical errors (#2911)

* Update grouped_gemm example and pipeline

* find the root cause error in did not enable the transpose in gfx950 correctly

* Fix v3 pipeline, row and col major

* Disable f8 datatype tests, it fails on gfx950

* fix the abd test by clear the runtime argument unsupported

---------

Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: Mateusz Ozga <mateusz.ozga@amd.com>

[ROCm/composable_kernel commit: b159841a06]
This commit is contained in:
Thomas Ning
2025-09-23 22:54:52 -07:00
committed by GitHub
parent 5cc40c160f
commit 8a563fc79d
9 changed files with 63 additions and 123 deletions

View File

@@ -132,6 +132,10 @@ struct GemmKernelMultiABD
static constexpr index_t NumBTensor = BsDataType::size();
static constexpr index_t NumDTensor = DsDataType::size();
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>;
CK_TILE_HOST static auto GetName() -> const std::string
{
return UniversalGemmKernel::GetName();
@@ -181,6 +185,14 @@ struct GemmKernelMultiABD
{
return false;
}
// Currently MultiABD kernel doesn't support F8 data type
if(ck_tile::get_device_name() == "gfx950" &&
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
std::is_same<ck_tile::fp8_t, DDataType>::value))
{
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}

View File

@@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
__builtin_amdgcn_sched_barrier(0);
@@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
@@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
}
if constexpr(is_b_row_major)
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
@@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm.LocalPrefetch(
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// __builtin_amdgcn_sched_barrier(0);

View File

@@ -100,7 +100,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
@@ -118,7 +118,7 @@ struct GemmPipelineProblemBase
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;