From e4a40c721439940cacfcd2daaad11297e97809ad Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Thu, 29 May 2025 08:19:31 +0000 Subject: [PATCH] Add fp8 profiler instances --- .../device_operation_instance_factory.hpp | 3 ++ .../tensor_operation_instance/gpu/gemm_mx.hpp | 16 ++++----- ...device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp | 2 +- ...device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp | 2 +- ...device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 32 ++++++++--------- ...l_f8_f8_bf16_mk_nk_mn_default_instance.cpp | 4 +-- .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 32 ++++++++--------- ...dl_f8_f8_f16_mk_nk_mn_default_instance.cpp | 4 +-- .../include/profiler/profile_gemm_mx_impl.hpp | 36 +++++++++---------- profiler/src/profile_gemm_mx.cpp | 30 +++++++++++----- 10 files changed, 88 insertions(+), 73 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index f0bb35415b..7fe2f9653d 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -25,6 +25,9 @@ using BF8 = ck::bf8_t; using I4 = ck::pk_i4_t; using F4 = ck::f4x2_pk_t; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; + using Empty_Tuple = ck::Tuple<>; using BF16_Tuple = ck::Tuple; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp index 88f6238d8f..ec75a0cfb0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -22,9 +22,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( Col, Row, F8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, F16, 32, PassThrough, @@ -36,9 +36,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( Col, Row, F8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, BF16, 32, PassThrough, @@ -64,9 +64,9 @@ void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( Row, Row, BF8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, F16, 32, PassThrough, @@ -78,9 +78,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( Col, Row, F8, - e8m0_bexp_t, + E8M0PK, F8, - e8m0_bexp_t, + E8M0PK, BF16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp index d7db026e51..cbfdb4bea0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp @@ -40,7 +40,7 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< -#if 0 // TODO: Fix v1 +#if 0 // TODO: Fix RRR // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp index d965ac1b22..3969ece55a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -39,7 +39,7 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< -#if 0 // TODO: Fix v1 +#if 0 // TODO: Fix CCR // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp index f1aa90b069..aa4704530d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -13,11 +13,12 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using F16 = half_t; -using BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -39,19 +40,18 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< -#if 0 // TODO: Fix v1 // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on -#endif >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp index 05914e06b5..a4c3451c47 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( Col, Row, F8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, BF16, 32, PassThrough, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp index c2142de177..1371e419ea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -13,11 +13,12 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = f8_t; -using F16 = half_t; -using BF16 = bhalf_t; -using F32 = float; -using E8M0 = ck::e8m0_bexp_t; +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; using Row = tensor_layout::gemm::RowMajor; using Col = tensor_layout::gemm::ColumnMajor; @@ -39,19 +40,18 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< -#if 0 // TODO: Fix v1 // clang-format off - //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + std::nullptr_t // clang-format on -#endif >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp index f4e59cf92d..1cacee7aea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp @@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( Col, Row, F8, - E8M0, + E8M0PK, F8, - E8M0, + E8M0PK, F16, 32, PassThrough, diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp index 46650b82fb..8135bf4475 100644 --- a/profiler/include/profiler/profile_gemm_mx_impl.hpp +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -126,9 +126,11 @@ bool profile_gemm_mx_impl(int do_verification, int n_iter, uint64_t rotating = 0) { - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - using MFMA = ck::tensor_layout::gemm::MFMA; + using tensor_operation::device::instance::Col; + using tensor_operation::device::instance::E8M0; + using tensor_operation::device::instance::E8M0PK; + using tensor_operation::device::instance::MFMA; + using tensor_operation::device::instance::Row; constexpr bool BPreShuffle = is_same_v; using BRefLayout = conditional_t; @@ -138,11 +140,10 @@ bool profile_gemm_mx_impl(int do_verification, throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; - using XDataType = e8m0_bexp_t; + using XDataType = E8M0; + using XPackedDataType = E8M0PK; using AScaleLayout = Row; using BScaleLayout = Col; - using XPackedDataType = // TODO: use int32 for all - conditional_t, int32_t, e8m0_bexp_t>; auto f_host_tensor_descriptor = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { @@ -225,7 +226,6 @@ bool profile_gemm_mx_impl(int do_verification, return ck::type_convert(x); }; - constexpr auto nthreads = 16; switch(init_method) { case 0: // Initializations for development and debugging @@ -245,23 +245,21 @@ bool profile_gemm_mx_impl(int do_verification, case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}, nthreads); // Z[-4,4] - b_k_n->GenerateTensorValue(GeneratorTensor_2{-4, 5}, nthreads); // Z[-4,4] + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + b_k_n->GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] - a_m_k_scale.GenerateTensorValue(GeneratorTensor_2{125, 129}, - nthreads); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue(GeneratorTensor_2{125, 129}, - nthreads); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}, nthreads); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}, - nthreads); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}, nthreads); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}, - nthreads); + b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); break; } diff --git a/profiler/src/profile_gemm_mx.cpp b/profiler/src/profile_gemm_mx.cpp index 78e1d24a26..9fd6f29464 100644 --- a/profiler/src/profile_gemm_mx.cpp +++ b/profiler/src/profile_gemm_mx.cpp @@ -18,7 +18,9 @@ enum struct GemmMatrixLayout enum struct GemmDataType { - F4_F4_F16, // 0 + F4_F4_F16, // 0 + F8_F8_F16, // 1 + F8_F8_BF16, // 2 }; #define OP_NAME "gemm_mx" @@ -29,10 +31,12 @@ int profile_gemm_mx(int argc, char* argv[]) if(argc != 11 && argc != 14 && argc != 18) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: f4->f16)\n"); - printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * BPreShuff = C[m, n];\n"); + printf("arg2: data type (0: f4->f16 ;\n"); + printf(" 1: fp8->f16 ;\n"); + printf(" 2: fp8->bf16 )\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n] ;\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n] ;\n"); + printf(" 2: A[k, m] * BPreShuff = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); @@ -77,8 +81,10 @@ int profile_gemm_mx(int argc, char* argv[]) rotating = std::stoull(argv[arg_index++]) * 1024 * 1024; } - using F16 = ck::half_t; - using F4 = ck::f4x2_pk_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F4 = ck::f4x2_pk_t; + using F8 = ck::f8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -126,10 +132,18 @@ int profile_gemm_mx(int argc, char* argv[]) { return profile(F4{}, F4{}, F16{}, Row{}, Col{}, Row{}); } - if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_MFMA_MN) + else if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_MFMA_MN) { return profile(F4{}, F4{}, F16{}, Row{}, MFMA{}, Row{}); } + else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, BF16{}, Row{}, Col{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl;