From e80e4bedba73d14f4fd566fcab5b9bcaa7b28554 Mon Sep 17 00:00:00 2001
From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Date: Mon, 2 Oct 2023 16:39:03 -0500
Subject: [PATCH] Add fp8 @ bf8 gemm support and example (#933)

* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Update naming to f8->fp8
* Update naming
* Format

[ROCm/composable_kernel commit: bd09b5c5389c9f7f51459e5a38d3d7253b0c8dc0]
---
 example/01_gemm/CMakeLists.txt | 17 +-
 ..._xdl_fp16_f8.cpp => gemm_xdl_fp16_fp8.cpp} | 0
 .../{gemm_xdl_f8.cpp => gemm_xdl_fp8.cpp} | 0
 example/01_gemm/gemm_xdl_fp8_bf8.cpp | 49 ++++++
 .../gpu/block/blockwise_gemm_xdlops.hpp | 127 ++++++++-------
 .../device/impl/device_gemm_xdl_cshuffle.hpp | 6 +-
 .../element/unary_element_wise_operation.hpp | 32 ++++
 ...iple_d_welford_first_half_xdl_cshuffle.hpp | 1 +
 ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 3 +-
 ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 7 +-
 ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 3 +-
 ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 3 +-
 ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 1 +
 ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 1 +
 .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 1 +
 ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 1 +
 .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 1 +
 ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 2 +
 .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 24 +--
 ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 1 +
 ...ridwise_gemm_xdl_waveletmodel_cshuffle.hpp | 1 +
 .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 1 +
 .../gpu/grid/gridwise_gemm_xdlops_streamk.hpp | 1 +
 .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 2 +
 .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 1 +
 .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 1 +
 .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 1 +
 .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 1 +
 .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 98 ++++++++++--
 include/ck/utility/amd_buffer_addressing.hpp | 147 +++++++++++-------
 include/ck/utility/amd_xdlops.hpp | 65 ++++++++
 .../cpu/reference_gemm.hpp | 7 +-
 .../library/utility/host_tensor_generator.hpp | 4 +-
 .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 24 ++-
 ...ixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp} | 0
 ...ixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp} | 0
 36 files changed, 475 insertions(+), 159 deletions(-)
 rename example/01_gemm/{gemm_xdl_fp16_f8.cpp => gemm_xdl_fp16_fp8.cpp} (100%)
 rename example/01_gemm/{gemm_xdl_f8.cpp => gemm_xdl_fp8.cpp} (100%)
 create mode 100644 example/01_gemm/gemm_xdl_fp8_bf8.cpp
 rename library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/{device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp => device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp} (100%)
 rename library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/{device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp => device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp} (100%)

diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt
index 2351d3f105..e0124e57f7 100644
--- a/example/01_gemm/CMakeLists.txt
+++ b/example/01_gemm/CMakeLists.txt
@@ -67,13 +67,20 @@ add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
 if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
- add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_f8) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp8) endif() endif() -add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) -if(result EQUAL 0) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) + if(result EQUAL 0) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) + endif() +endif() + +add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) +if(result EQUAL 0) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) endif() diff --git a/example/01_gemm/gemm_xdl_fp16_f8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp similarity index 100% rename from example/01_gemm/gemm_xdl_fp16_f8.cpp rename to example/01_gemm/gemm_xdl_fp16_fp8.cpp diff --git a/example/01_gemm/gemm_xdl_f8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp similarity index 100% rename from example/01_gemm/gemm_xdl_f8.cpp rename to example/01_gemm/gemm_xdl_fp8.cpp diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000..0d69b7a90f --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; +using CDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::bf8_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| 
_NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 1fee9c3225..904a96cc9f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -28,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) } template {}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -294,9 +295,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -318,25 +319,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -356,8 +359,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -366,8 +369,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -385,7 +388,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // the latest ROCm release. 
For unsupported compilers, inter-wave loop scheduler falls back to the // default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -479,20 +485,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -514,8 +522,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { @@ -541,8 +549,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, I1, I1, Number{})); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -551,8 +559,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -568,7 +576,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }; template {}.K0PerXdlops, - index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -654,7 +666,8 @@ struct BlockwiseGemmXdlops_v2 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 86bc7b5bbb..693fd7048c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -66,7 +66,8 @@ template + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index b181b66009..927574dfde 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -156,6 +156,38 @@ struct PassThrough y = type_convert(x); } #endif + +#if defined CK_ENABLE_BF8 + template <> + __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const half_t& x) const + { + y = type_convert(x); + } +#endif }; struct UnaryConvert diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index b25f136a37..206ea00b9d 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index c2f47bd444..9469fa7bc7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index d2920570e4..a0924ae3b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -880,7 +880,12 @@ struct 
GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{} + Gemm1KPack * XdlopsGemm{} .K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 18cfeebcf3..bc76d4cc4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index f4b82badf1..afb2ad2e76 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 3ced4b9ad6..9f5df8bdbf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index d710fc1894..5c9f40b51a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 1d920fb44d..c7ec435477 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -568,6 +568,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle auto blockwise_gemm = 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ComputeDataType, + ComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index a3343bd3a0..4f37049b20 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -602,6 +602,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ComputeType, + ComputeType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 99e410f688..d75b631e61 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 18cf80041b..e7dc0d3eb0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 9c09f3a539..d700314a26 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -108,7 +108,8 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ComputeType) + - b_block_space_size_aligned * sizeof(ComputeType)), + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), c_block_size * sizeof(FloatCShuffle)); } @@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + ComputeTypeA, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + ComputeTypeB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, + ComputeTypeA, + ComputeTypeB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 0404d88ab8..013120c540 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index bbd01a238e..8675a9242a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, ABDataType, + ABDataType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 0920a17fc7..8897ce5a10 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -737,6 +737,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight auto blockwise_gemm = 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 @@ -502,10 +504,62 @@ struct mfma_type }; #endif -template +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); + } +}; +#endif + +template struct MfmaSelector { - template + template static constexpr auto GetMfma(); template <> @@ -656,7 +710,22 @@ struct MfmaSelector } #endif - static constexpr auto selected_mfma = mfma_type()>{}; +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8bf8; + } +#endif + + static constexpr auto selected_mfma = + mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -703,7 +772,8 @@ template + typename additional_type = base_type, + bool TransposeC = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -854,14 +924,18 @@ struct XdlopsGemm template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value + static_assert( + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value #if defined CK_ENABLE_FP8 - || is_same::value + || is_same::value #endif - , - "base base_type must be double, float, half, bfloat16, and int8_t!"); +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + || (is_same::value && is_same::value) +#endif + , + "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) @@ -957,7 +1031,7 @@ struct XdlopsGemm return TransposeC ? 
CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 694027100f..d8094d84d2 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1127,37 +1127,53 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; -#if defined CK_ENABLE_FP8 - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - return bit_cast(tmp); - } - else - { +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value || is_same::value) #endif - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); -#if defined CK_ENABLE_FP8 - } +#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + return bit_cast(tmp); + } + else + { +#endif + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + } #endif #else -#if defined CK_ENABLE_FP8 - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); - } - else - { +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value || is_same::value) #endif - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? tmp : vector_t(0); -#if defined CK_ENABLE_FP8 - } +#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); + } + else + { +#endif + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? tmp : vector_t(0); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + } #endif #endif } @@ -1216,40 +1232,61 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; -#if defined CK_ENABLE_FP8 - if constexpr(is_same::value) - { - auto tmp = - bit_cast::type::type>(src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); - } - else - { + +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value || is_same::value) #endif - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#if defined CK_ENABLE_FP8 - } +#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + { + auto tmp = bit_cast::type::type>( + src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + } + else + { +#endif + amd_buffer_store_impl(src_thread_data, + dst_wave_buffer_resource, + dst_addr_shift + + dst_thread_addr_offset, + 0); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + } #endif #else if(dst_thread_element_valid) { -#if defined CK_ENABLE_FP8 - if constexpr(is_same::value) - { - auto tmp = bit_cast::type::type>( - src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } - else - { +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value || is_same::value) #endif - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); -#if defined CK_ENABLE_FP8 - } +#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same::value) +#endif +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + { + auto tmp = + bit_cast::type::type>( + src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + else + { +#endif + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 + } #endif } #endif diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a80540515a..3768f92633 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -419,5 +419,70 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> } }; #endif + +#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +template +struct intrin_mfma_f32_32x32x16f8bf8; + +template <> +struct intrin_mfma_f32_32x32x16f8bf8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8bf8; + +template <> +struct intrin_mfma_f32_16x16x32f8bf8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if 
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; +#endif } // namespace ck #endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 95bd1e13d9..6e39dee71c 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -21,7 +21,8 @@ template + typename ComputeTypeA = ADataType, + typename ComputeTypeB = ComputeTypeA> struct ReferenceGemm : public device::BaseOperator { // Argument @@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator for(int k = 0; k < K; ++k) { - ComputType v_a; - ComputType v_b; + ComputeTypeA v_a; + ComputeTypeB v_b; // use PassThrough instead of ConvertBF16RTN for reference calculation if constexpr(is_same_v } }; -#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +#if defined CK_ENABLE_FP8 template <> struct GeneratorTensor_2 { @@ -143,7 +143,7 @@ struct GeneratorTensor_3 } }; -#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +#if defined CK_ENABLE_FP8 template <> struct GeneratorTensor_3 { diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index 45f8130c57..51a42c3d8d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,10 +1,18 @@ -add_instance_library(device_grouped_gemm_fixed_nk_instance - device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp +set(GROUPED_GEMM_FIXED_NK_INSTANCES) - device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp) +endif() - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp -) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp) +endif() + +if((DTYPES MATCHES "int8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp) + list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp) +endif() + 
+add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
similarity index 100%
rename from library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp
rename to library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
similarity index 100%
rename from library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp
rename to library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
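
The host-side reference this patch touches (reference_gemm.hpp) reduces to a plain triple loop in which A and B are first converted to separate compute types (ComputeTypeA / ComputeTypeB) before the multiply and the products are accumulated in AccDataType; that split is what allows an fp8 A operand to be paired with a bf8 B operand while the example still accumulates in float. The sketch below illustrates that idea only; it is not the CK implementation. The function name reference_gemm, the float stand-ins for ck::f8_t / ck::bf8_t, and the static_cast conversions are assumptions made so the example compiles without CK headers (the real ReferenceGemm uses type_convert and the element-wise operators).

// Minimal, self-contained sketch of a reference GEMM with split compute types.
// All names here are illustrative; float stands in for the 8-bit formats.
#include <cstddef>
#include <iostream>
#include <vector>

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = ComputeTypeA>
void reference_gemm(const std::vector<ADataType>& a, // M x K, row-major
                    const std::vector<BDataType>& b, // K x N, row-major
                    std::vector<CDataType>& c,       // M x N, row-major
                    std::size_t M,
                    std::size_t N,
                    std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            AccDataType acc{0};
            for(std::size_t k = 0; k < K; ++k)
            {
                // Convert each operand to its own compute type first, mirroring
                // the v_a / v_b split introduced for mixed fp8 x bf8 inputs.
                const ComputeTypeA v_a = static_cast<ComputeTypeA>(a[m * K + k]);
                const ComputeTypeB v_b = static_cast<ComputeTypeB>(b[k * N + n]);
                acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
            }
            c[m * N + n] = static_cast<CDataType>(acc);
        }
    }
}

int main()
{
    const std::size_t M = 2, N = 2, K = 4;
    std::vector<float> a(M * K, 1.0f), b(K * N, 0.5f), c(M * N, 0.0f);
    reference_gemm<float, float, float, float>(a, b, c, M, N, K);
    std::cout << c[0] << '\n'; // prints 2
    return 0;
}

Instantiated with float for every type, as in main() above, this produces the values the fp8 @ bf8 device path should reproduce up to the quantization error of the 8-bit formats; using two independent compute-type parameters rather than a single ComputeType is the same design choice made throughout the gridwise and blockwise kernels in this patch.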