Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845)

* Support bf16/fb8/bf8 datatypes for ck_tile/gemm

* remove commented out code.

* Addressing code review comments and enabling universal_gemm for all the supported data types.

* Merge conflict resolution.

* Solve the memory pipeline compilation error. Merge with the new change of CShuffle

* finish the feature, pass the tests

* Fix the pipeline and add the benchmark script for other data types

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
kylasa
2025-02-06 14:07:38 -08:00
committed by Sam Wu
parent 9b5dfba242
commit ab5d027866
21 changed files with 598 additions and 88 deletions

View File

@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
@@ -240,7 +240,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -255,7 +255,7 @@ struct GemmKernel
<< std::endl;
return false;
}
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{});
}
else
@@ -519,7 +519,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(