Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent c12fb0a624
commit 66c5f5b0b6
32 changed files with 511 additions and 245 deletions

View File

@@ -167,7 +167,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
@@ -275,7 +275,7 @@ struct GemmKernel
}
return false;
}
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -295,7 +295,7 @@ struct GemmKernel
}
return false;
}
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
@@ -407,7 +407,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
@@ -671,7 +671,7 @@ struct GemmKernel
}
else
{
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
@@ -694,7 +694,7 @@ struct GemmKernel
}
else
{
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(