[CK TILE] Fix double lds in ck tile gemm (#1924)

This commit is contained in:
Bartłomiej Kocot
2025-02-28 17:07:53 +01:00
committed by GitHub
parent faa2235dad
commit 1bf29478cd
2 changed files with 20 additions and 17 deletions

View File

@@ -654,11 +654,11 @@ struct GemmKernel
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
if(kargs.k_batch == 1)
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
__shared__ char smem_ptr_1[GetSmemSize()];
if(kargs.k_batch == 1)
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemm2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -671,19 +671,9 @@ struct GemmKernel
}
else
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
}
else
{
// Do not compile this branch for the unsupported configuration: an odd
// VectorSizeC combined with an fp16_t or bf16_t CDataType.
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
b_ptr,
c_ptr,
@@ -694,7 +684,18 @@ struct GemmKernel
i_m,
i_n);
}
else
}
}
else
{
if(kargs.k_batch == 1)
{
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
else
{
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);