mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK TILE] Fix double lds in ck tile gemm (#1924)
[ROCm/composable_kernel commit: 1bf29478cd]
This commit is contained in:
@@ -654,11 +654,11 @@ struct GemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
if(kargs.k_batch == 1)
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
@@ -671,19 +671,9 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Do not compile in case where we have unsupported
|
||||
// VectorSizeC & data type configuration.
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
@@ -694,7 +684,18 @@ struct GemmKernel
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
|
||||
@@ -71,7 +71,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
|
||||
// values.
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
|
||||
Reference in New Issue
Block a user