Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: 66c5f5b0b6]
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent 710aa99819
commit 7fbcd06a62
32 changed files with 511 additions and 245 deletions

View File

@@ -361,7 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
if constexpr(N == 2)
{
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
}
else if constexpr(N == 4)
{

View File

@@ -523,7 +523,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
int exponent = (x & 0x7F) >> SrcT_mant;
if constexpr(is_fnuz)
{
if(x == 0x80)
if((x & 0xff) == 0x80)
{
return fNaN;
}

View File

@@ -9,7 +9,9 @@
namespace ck_tile {
template <typename AccDataType_,
template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
index_t kBlockSize_,
@@ -23,6 +25,8 @@ template <typename AccDataType_,
bool isCTransposed_>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using CLayout = remove_cvref_t<CLayout_>;
@@ -40,9 +44,13 @@ struct CShuffleEpilogueProblem
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using Problem = remove_cvref_t<Problem_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ODataType, BDataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
@@ -56,8 +64,8 @@ struct CShuffleEpilogue
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
@@ -77,7 +85,6 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
@@ -143,7 +150,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC<ODataType>(),
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();

View File

@@ -167,7 +167,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
@@ -275,7 +275,7 @@ struct GemmKernel
}
return false;
}
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -295,7 +295,7 @@ struct GemmKernel
}
return false;
}
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -407,7 +407,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
@@ -671,7 +671,7 @@ struct GemmKernel
}
else
{
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
@@ -694,7 +694,7 @@ struct GemmKernel
}
else
{
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(

View File

@@ -33,8 +33,21 @@ struct BaseGemmPipelineAgBgCrCompV3
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
}
else
{
if(num_loop == 1)
{
return TailNumber::Odd;
}
else
{
return TailNumber::Even;
}
}
}
};
@@ -470,6 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
@@ -478,12 +492,43 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}

View File

@@ -143,7 +143,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
@@ -442,6 +442,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
block_sync_lds();
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);

View File

@@ -26,6 +26,10 @@ struct GemmPipelineAGmemBGmemCRegV1
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -81,11 +85,21 @@ struct GemmPipelineAGmemBGmemCRegV1
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);