diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 3e8c9afd9d..11933f09a9 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,6 +1,12 @@ add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) + +add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp) +add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp) +add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp) + if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc new file mode 100644 index 0000000000..c35b6be53e --- /dev/null +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n + |------------------| + Gemm0 + |-----------------------------| + Gemm1 +*/ + +static constexpr auto PipeSched = ck::BlockGemmPipelineScheduler::Interwave; +static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +template +using S = ck::Sequence; + +// clang-format off +// #define CK_MHA_USE_RCCR_LAYOUT +#define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +// #define CK_MHA_USE_WAVE_8 + +#ifdef CK_MHA_USE_RCCR_LAYOUT +using DeviceMHAFactory = + std::tuple< + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Col, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 32, + // Gemm 0 + 16, 64, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + PipeSched, PipelineVer> + >; +#else +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 32, + // Gemm 0 + 16, 128, 64, 64, 64, 8, 8, + // Gemm 
1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + PipeSched, PipelineVer>, + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 32, + // Gemm 0 + 16, 64, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + PipeSched, PipelineVer> +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 64, + // Gemm 0 + 32, 128, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + PipeSched, PipelineVer>, + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 64, + // Gemm 0 + 32, 64, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + PipeSched, PipelineVer> +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 128, + // Gemm 0 + 64, 128, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + PipeSched, PipelineVer>, + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, 
B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 128, + // Gemm 0 + 64, 64, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + PipeSched, PipelineVer> +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 256, + // Gemm 0 + 128, 128, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + PipeSched, PipelineVer>, + ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3< + Row, Col, Row, Row, + ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, + 256, + // Gemm 0 + 128, 128, 64, 64, 64, 8, 8, + // Gemm 1 + 8, + 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + PipeSched, PipelineVer> +#endif + >; +#endif + +// clang-format on +// Ref Gemm0 +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Gemm1 +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc" + +int main(int argc, char* argv[]) +{ + bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + return run(argc, argv); +} diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp new file mode 100644 index 0000000000..d1a8e7f30c --- /dev/null +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include 
"ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using B0DataType = BF16; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = BF16; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc" diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp new file mode 100644 index 0000000000..57b95b7412 --- /dev/null +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc" diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp new file mode 100644 index 0000000000..82ac374efe --- /dev/null +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::f8_t; +using B0DataType = ck::f8_t; +using B1DataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::f8_t; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = 
ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc" diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp new file mode 100644 index 0000000000..44d63a0d43 --- /dev/null +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp @@ -0,0 +1,34 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using B0DataType = int8_t; +using B1DataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using CDataType = int8_t; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc" diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc new file mode 100644 index 0000000000..8ab47c2925 --- /dev/null +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 113; +#ifdef CK_MHA_USE_RCCR_LAYOUT + ck::index_t N = 480; // Must be multiple of 8 even with padding. +#else + ck::index_t N = 477; +#endif + ck::index_t K = 200; // Must be multiple of 8 even with padding. + ck::index_t O = 208; // Must be multiple of 8 even with padding. 
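+    // M and N are deliberately not multiples of the block tiles so the default run
+    // exercises the MNKOPadding specialization.
+    // Example invocation (shapes are illustrative): flags first, then M N K O G, then alpha:
+    //   ./example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 1 1 1 128 128 64 128 2 1.0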
+ ck::index_t G = 91; // Batch + + float alpha = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G = std::stoi(argv[8]); + + alpha = std::stof(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 8: M, N, K, O, G\n"); + printf("arg9: scale (alpha)\n"); + exit(0); + } + + std::vector a_g_m_k_lengths{G, M, K}; + std::vector a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K] + std::vector b0_g_n_k_lengths{G, N, K}; + std::vector b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K] + std::vector b1_g_o_n_lengths{G, O, N}; +#ifdef CK_MHA_USE_RCCR_LAYOUT + std::vector b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N] +#else + std::vector b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O] +#endif + std::vector c_g_m_o_lengths{G, M, O}; + std::vector c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O] + + Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides); + Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides); + Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides); + Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides); + Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl; + std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_1{}); + 
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_device_buf.ToDevice(b0_g_n_k.mData.data()); + b1_device_buf.ToDevice(b1_g_o_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker_ptr = gemm.MakeInvokerPointer(); + auto argument_ptr = + gemm.MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G, // Batch, + a_g_m_k_strides[1], // StrideA, + b0_g_n_k_strides[1], // StrideB0, +#ifdef CK_MHA_USE_RCCR_LAYOUT + b1_g_o_n_strides[1], // StrideB1, +#else + b1_g_o_n_strides[2], // StrideB1, +#endif + c_g_m_o_strides[1], // StrideC, + a_g_m_k_strides[0], // BatchStrideA + b0_g_n_k_strides[0], // BatchStrideB0 + b1_g_o_n_strides[0], // BatchStrideB1 + c_g_m_o_strides[0], // BatchStrideC + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument_ptr.get())) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + return; + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + G; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + Tensor b0_g_k_n({G, K, N}); + Tensor b1_g_n_o({G, N, O}); + 
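+            // Host reference path: permute B0 and B1 into K-/N-contiguous layouts, run the
+            // two reference batched GEMMs below, then compare against the device result.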
Tensor acc0_g_m_n({G, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({G, M, N}); // scratch object after conversion + + // permute + b0_g_n_k.ForEach( + [&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); }); + b1_g_o_n.ForEach( + [&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + // Passthrough instead of softmax, DOES involve data type conversion. + a1_g_m_n(idx) = ck::type_convert(self(idx)); + }); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData, + c_g_m_o_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K + << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp index bfb081330c..8cff087ddb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -27,7 +27,8 @@ template + index_t KPack, + bool TransposeC = false> constexpr auto BlockGemmPipeline_Selector() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) @@ -50,7 +51,8 @@ constexpr auto BlockGemmPipeline_Selector() NPerWmma, MRepeat, NRepeat, - KPack>{}; + KPack, + TransposeC>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { @@ -72,7 +74,8 @@ constexpr auto BlockGemmPipeline_Selector() NPerWmma, MRepeat, NRepeat, - KPack>{}; + KPack, + TransposeC>{}; } else { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 
6fb62bc677..28c871ae0d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -277,6 +277,21 @@ struct BlockwiseGemmWmmaops_pipeline_base "wrong!"); } + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index f25648efa6..76d748eb27 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -31,7 +31,8 @@ template + index_t KPack, + bool TransposeC = false> struct BlockwiseGemmWmmaops_pipeline_v1 { }; @@ -53,7 +54,8 @@ template + index_t KPack, + bool TransposeC> struct BlockwiseGemmWmmaops_pipeline_v1 + KPack, + TransposeC> : BlockwiseGemmWmmaops_pipeline_base - + KPack, + TransposeC> { using Base = BlockwiseGemmWmmaops_pipeline_base; + KPack, + TransposeC>; using Base::I0; using Base::A_K1; @@ -329,7 +333,8 @@ template + index_t KPack, + bool TransposeC> struct BlockwiseGemmWmmaops_pipeline_v1 + KPack, + TransposeC> : BlockwiseGemmWmmaops_pipeline_base - + KPack, + TransposeC> { using Base = BlockwiseGemmWmmaops_pipeline_base; + KPack, + TransposeC>; using Base::I0; using Base::I1; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 8fed23d151..83dadb2175 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -31,7 +31,8 @@ template + index_t KPack, + bool TransposeC = false> struct BlockwiseGemmWmmaops_pipeline_v3 { }; @@ -53,7 +54,8 @@ template + index_t KPack, + bool TransposeC> struct BlockwiseGemmWmmaops_pipeline_v3 + KPack, + TransposeC> : BlockwiseGemmWmmaops_pipeline_base + KPack, + TransposeC> { using Base = BlockwiseGemmWmmaops_pipeline_base; + KPack, + TransposeC>; using Base::I0; using Base::A_K1; @@ -128,6 +133,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop % num_loop_per_scale == 0); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using 
wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + + // Tail, always perform. { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..32669800cd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,788 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + arg.p_b1_grid + b1_batch_offset, + arg.p_c_grid + c_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.c_element_op, + arg.block_2_ctile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ 
(Acc1) +template +struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm +{ + using DeviceOp = DeviceBatchedGemmGemm_Wmma_CShuffleV3; + + static constexpr auto I0 = Number<0>{}; + + // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal + // to LPerWmma (A.k.a Gemm0 NPerWmma). + static constexpr index_t NPerWmma = LPerWmma; + + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB0_(BatchStrideB0), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB0_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< + // DataType Family + ADataType, + B0DataType, + AccDataType, // Acc0DataType + B1DataType, + AccDataType, // Acc1DataType + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + 
MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + BlkGemmPipeSched, + BlkGemmPipelineVer>; + + struct RawArg : public BaseArgument + { + using arr3 = std::array; + + RawArg(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + const B1DataType* p_b1_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t StrideA, + index_t StrideB0, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_b1_grid{p_b1_grid_}, + p_c_grid{p_c_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? 
arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + c_g_m_o_lengths = arr3{batch_count, M, O}; + c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O] + + a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); + b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); + b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); + c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides); + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + } + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + const B1DataType* p_b1_grid; + CDataType* p_c_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + arr3 c_g_m_o_lengths; + arr3 c_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + // Grid descriptors and other mem calculators + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n; + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map; + + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; + }; + + static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) + { + // Print lambda with env check and printf() style formmating. + const char* curFunc = __func__; + auto print = [&curFunc](const char* format, ...) -> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: B layout must be Column\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. 
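+        // Note: MNKOPadding pads all four GEMM dimensions (M, N, K, O) up to block-tile
+        // multiples, which is what lets the odd default shapes in the example run.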
+ if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. + if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n, + arg.block_2_ctile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto c_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; + const auto c_stride_lowest = arg.c_g_m_o_strides[2]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); + + const index_t grid_size = arg.batch_count * M0 * N0; + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_loop = decltype(has_main_k_block_loop)::value; + constexpr TailNumber tn = tail_number; + + const auto kernel = + kernel_batched_gemm_gemm_wmma_cshuffle_v3; + + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + 
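+                // v3 pipeline dispatch: a Full tail requires the main K block loop;
+                // without a main loop the leftover iterations run as an Even or Odd tail
+                // (see the branches below).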
if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + template + static constexpr const char* DataTypeToString() + { + if constexpr(std::is_same_v) + { + return "fp32"; + } + else if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "int32"; + } + else if constexpr(std::is_same_v) + { + return "int8"; + } + else if constexpr(std::is_same_v) + { + return "int4"; + } + else + { + return "unknown"; + } + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3" + << "<" + << ALayout::name[0] + << B0layout::name[0] + << B1Layout::name[0] + << CLayout::name[0] << ", " + << "A " << DataTypeToString() << ", " + << "B0 " << DataTypeToString() << ", " + << "B1 " << DataTypeToString() << ", " + << "C " << DataTypeToString() << ", " + << "Acc " << DataTypeToString() << ", " + << "Cshuf " << DataTypeToString() << ", " + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + 
<< AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << "BlkGemmPipelinePrefetchStages: " + << GridwiseOp::BlockwiseGemmPipe::PrefetchStages; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..b61c7a09eb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,1127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] +// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +template +struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto AK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto L0PerBlock = LTilePerBlock / L1Value; + static constexpr auto AL0 = Number{}; // TODO: Where does this 2 come from? + static constexpr auto AL1 = Number{}; + static constexpr auto BL0 = Number{}; + static constexpr auto BL1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + // TODO: I am pretty sure this is always 16 and *should* always be 16. 
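+    // KPack = lcm(lcm(AK1Value, BK1Value), 16), since integer_least_multiple computes the
+    // lcm; for the AK1 = BK1 = 8 tunings used in the examples this evaluates to 16.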
+ static constexpr auto KPack = + math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / AK1; + constexpr auto max_lds_align = AK1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeB0BlockDescriptor() + { + constexpr auto b0_block_desc = [&]() { + // K0->L->BK1 Per Block + constexpr auto K0PerBlock = KPerBlock / BK1; + constexpr auto max_lds_align = BK1; + + if constexpr(B0BlockLdsExtraL) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b0_block_desc; + } + + __host__ __device__ static constexpr auto MakeB1BlockDescriptor() + { + constexpr auto b1_block_desc = [&]() { + // L0->N->BL1 Per Block + constexpr auto max_lds_align = BL1; + + if constexpr(B1BlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BL1), + make_tuple(Number{} * BL1, BL1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BL1), max_lds_align); + } + }(); + + return b1_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { return make_multi_index(AK0, 0, 0); }(); + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep() + { + constexpr auto b0_block_copy_step = [&]() { return make_multi_index(BK0, 0, 0); }(); + return b0_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep() + { + constexpr auto b1_block_copy_step = [&]() { return make_multi_index(L0PerBlock, 0, 0); }(); + return b1_block_copy_step; + } + + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + constexpr auto a_wave_desc = [&]() { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) + { + constexpr auto b0_wave_desc = [&]() { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1 + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return 
transform_tensor_descriptor( + B0BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + }(); + + return b0_wave_desc; + } + + template + __host__ __device__ static constexpr auto + MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&) + { + constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); + constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); + constexpr auto A_LRow = I1; + return transform_tensor_descriptor( + A1BlockDesc_AL0_M_AL1{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_LRow)), + make_unmerge_transform(make_tuple(Number{}, I1, I1)), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + template + __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&) + { + + constexpr auto b1_wave_desc = [&]() { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else + constexpr auto B_LRow = I1; +#endif + return transform_tensor_descriptor( + B1BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + }(); + + return b1_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + const index_t gemm0_bytes_end = + (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) + + SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType)); + + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + + SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType)); + + const index_t acc0_bytes_end = + SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType); + + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, acc0_bytes_end, c_block_bytes_end); + } + + // Blockwise gemm pipeline for gemm0, this replaces the old GridwiseGemmPipe + + // BlockwiseGemmWMMA. The latter had two enableLDS bools which we don't + // have anymore, the new pipelines ALWAYS use lds. Furthermore the original BlockwiseGemmWMMA + // used TransposeC = true which we still need to make the operation work. 
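+    // Illustrative 2x2 check of the transpose identity relied on below,
+    // (A x B)' = B' x A':
+    //   A = [[1,2],[3,4]], B = [[5,6],[7,8]]  =>  A x B = [[19,22],[43,50]]
+    //   B' x A' = [[5,7],[6,8]] x [[1,3],[2,4]] = [[19,43],[22,50]] = (A x B)'
+    // So a pipeline that computes C' = B' x A' and reads its accumulator
+    // transposed recovers the original A x B.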
+    using BlockwiseGemmPipe =
+        remove_cvref_t())>; // TransposeC (must be true to work), C' = B' x A'
+
+    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
+    template <typename AGridDesc,
+              typename B0GridDesc,
+              typename B1GridDesc,
+              typename CGridDesc_M_N,
+              typename Block2CTileMap>
+    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
+                                                            const B0GridDesc& b0_grid_desc,
+                                                            const B1GridDesc& b1_grid_desc,
+                                                            const CGridDesc_M_N& c_grid_desc_m_n,
+                                                            const Block2CTileMap& block_2_ctile_map)
+    {
+        // Print lambda with env check and printf()-style formatting.
+        const char* curFunc = __func__;
+        auto print = [&curFunc](const char* format, ...) -> void {
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wformat-nonliteral"
+#endif
+                va_list args;
+                va_start(args, format);
+                std::vfprintf(stdout, format, args);
+                va_end(args);
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+                std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
+            }
+        };
+
+        static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
+                          LPerBlock % (LPerWmma * LRepeat) == 0 &&
+                          NPerBlock % (NPerWmma * NRepeat) == 0,
+                      "Invalid tuning param!");
+
+        const auto M = a_grid_desc.GetLength(I1);
+        const auto L = b0_grid_desc.GetLength(I1);
+        const auto K = a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
+        const auto N = b1_grid_desc.GetLength(I1);
+
+        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
+        {
+            print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n",
+                  M,
+                  N,
+                  c_grid_desc_m_n.GetLength(I0),
+                  c_grid_desc_m_n.GetLength(I1));
+            return false;
+        }
+
+        if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
+        {
+            print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = "
+                  "%d, %d, %d, %d\n",
+                  M,
+                  L,
+                  K,
+                  N,
+                  MPerBlock,
+                  LPerBlock,
+                  KPerBlock,
+                  NPerBlock);
+            return false;
+        }
+
+        // check gemm1 gridwise gemm pipeline
+        if(!(LPerBlock % LTilePerBlock == 0))
+        {
+            print("GridwiseOp: inner loop division, L/LTilePerBlock: %d, %d\n",
+                  LPerBlock,
+                  LTilePerBlock);
+            return false;
+        }
+
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            print("GridwiseOp: invalid block_2_ctile_map\n");
+            return false;
+        }
+
+        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
+        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
+    }
+
+    __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
+    {
+        const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
+        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+        return c_grid_desc_mblock_mperblock_nblock_nperblock;
+    }
+
+    //
return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + MakeABlockDescriptor().GetElementSpaceSize(), max_lds_align); + static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( + MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a_block_space_size_aligned; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + // Feature to add, IntraThread Reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const B0ElementwiseOperation& b0_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
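+// Stage overview (annotation): wrap the raw grid pointers in dynamic buffers, map
+// blockIdx.x to an (MBlock, NBlock) tile of C, run the gemm0 pipeline over K into
+// the acc0 thread buffer, stream acc0 through a1_thread_buf into gemm1 over L,
+// then CShuffle the finished tile out through LDS to global memory.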
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// set up Gemm0 +/*******************************************************************************/ + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b0_block_desc = MakeB0BlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + constexpr auto AK0PerBlock = KPerBlock/ AK1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + }; + + auto b0_block_trait = [&](){ + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + 
SharedMemTrait::b0_block_space_size_aligned); + + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0DataType, + B0DataType, + decltype(b0_grid_desc), + decltype(b0_block_desc), + B0BlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + 1, + 1, + B0ThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b0_grid_desc, + make_multi_index(0, 0, 0), + b0_element_op, + b0_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b0_block_buf = b0_block_trait()[I0]; + auto b0_blockwise_copy = b0_block_trait()[I1]; + +/*******************************************************************************/ + // Gemm0 + // Blockwise GEMM0 pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm0_pipeline = BlockwiseGemmPipe{}; + auto acc0_thread_buf = blockwise_gemm0_pipeline.GetCThreadBuffer(); + + // Note that we are using the transposeC version of GetCThreadDescriptor. + constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0_pipeline.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( + acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), + make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)), + make_pass_through_transform(laccvgprs)), + make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep(); + + const auto a_block_reset_copy_step = [&](){ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); + }(); + + const auto b0_block_reset_copy_step = [&](){ + return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); + }(); + + const auto K = [&](){ + return a_grid_desc.GetLength(I0) * 
a_grid_desc.GetLength(I2); + }(); + + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + +/*******************************************************************************/ +// set up Gemm1 +/*******************************************************************************/ + // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm + // A1 matrix in VGPR + constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple( + Number{}, + Number{}, + Number{}); + + constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0]; + constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1]; + constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2]; + + constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor( + make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1), + make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1)); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + ADataType, + decltype(acc0_thread_desc_l0perblock_mperblock_l1), + decltype(a1_thread_desc_l0perblock_mperblock_l1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2>, + 2, + laccvgprs>{tensor_operation::element_wise::PassThrough{}}; + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize()); + + constexpr auto b1_block_desc = MakeB1BlockDescriptor(); + + auto b1_block_trait = [&](){ + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + SharedMemTrait::b1_block_space_size_aligned); + + auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, +/* typename SrcElementwiseOperation, */ B1ElementwiseOperation, +/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1, +/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ B1DataType, +/* typename DstData, */ B1DataType, +/* typename SrcDesc, */ decltype(b1_grid_desc), +/* typename DstDesc, */ decltype(b1_block_desc), +/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord + 1>( // Used to be NumGemmKPrefetchStage, never tested / used for != 1 + b1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + }; + + auto b1_block_buf = b1_block_trait()[I0]; + auto b1_blockwise_copy = b1_block_trait()[I1]; + + constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep(); + + auto 
blockwise_gemm1 = + BlockwiseGemmWMMA{make_tuple(0, 0, 0, 0, 0, 0)}; + + auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const auto L = [&](){ + return b0_grid_desc.GetLength(I1); + }(); + + const index_t num_gemm1_l_block_outer_loop = L / LPerBlock; + constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock; + + // Initialize C + StaticBuffer c_thread_buf; + c_thread_buf.Clear(); + + // Empty BScale struct for the blockwise pipeline. + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + +/*******************************************************************************/ + // + // Kernel Main Stage + // + index_t gemm1_l_block_outer_index = 0; + // Outer loop, along GEMM_L + // Inner loop, along GEMM_K + do { + blockwise_gemm0_pipeline.template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b0_grid_desc, + b0_block_desc, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + acc0_thread_buf, + b_scale_struct, + KBlockMainLoop, + 1); // num_k_block_per_scale + + static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + + block_sync_lds(); + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + + // main body + if constexpr(num_gemm1_l_block_inner_loop > 1) + { + static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { + // Data cast from Acc0DataType to ADataType happens here + a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple( + Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = 
c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, + c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax. + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, + a_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc, + b0_block_reset_copy_step); // rewind K and step N + + block_sync_lds(); // wait for gemm1 LDS read + }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( + 
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma + )), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NSubGroup, + NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 8, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 
0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 64327d0142..8e53728ef6 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -243,6 +243,30 @@ inline __host__ __device__ constexpr half_t type_convert_sp(int x) return u.fp16; } +template <> +inline __host__ __device__ constexpr int type_convert_sp(f8_t x) +{ + union + { + f8_t fp8; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr f8_t type_convert_sp(int x) +{ + union + { + int int32; + f8_t fp8; + } u = {x}; + + return u.fp8; +} + template <> inline __host__ __device__ constexpr int type_convert_sp(bhalf_t x) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index 42ca8e755d..95d837f235 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -16,6 +16,70 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance( + std::vector>>& instances); +#endif // CK_ENABLE_BF16 +#ifdef CK_ENABLE_FP16 +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL template > op_ptrs; - +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance( + op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA +#ifdef CK_USE_XDL +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -103,10 +208,11 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL return op_ptrs; } }; -#endif } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 2aa607429d..527c40fcd9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,5 +1,9 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_gemm_instance + device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp + device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp 
b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..d6bea4d36c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3; + +// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n] +// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o] +// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k" +// dimensions instead! 
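+// Illustrative only: a naive host-side reference for the fused operation these
+// instances implement (hypothetical helper, not part of this change; float is
+// used instead of bf16 for clarity). Layouts follow the file name: A is gmk,
+// B0 is gnk (column-major as a K x N matrix), B1 is gno, C is gmo.
+#if 0
+static void reference_batched_gemm_gemm(const float* a, const float* b0,
+                                        const float* b1, float* c,
+                                        int G, int M, int N, int K, int O)
+{
+    for(int g = 0; g < G; ++g)
+        for(int m = 0; m < M; ++m)
+            for(int o = 0; o < O; ++o)
+            {
+                float acc1 = 0.f;
+                for(int n = 0; n < N; ++n)
+                {
+                    // gemm0: Acc[g, m, n] = sum_k A[g, m, k] * B0[g, n, k]
+                    float acc0 = 0.f;
+                    for(int k = 0; k < K; ++k)
+                        acc0 += a[(g * M + m) * K + k] * b0[(g * N + n) * K + k];
+                    // gemm1: C[g, m, o] accumulates Acc[g, m, n] * B1[g, n, o]
+                    acc1 += acc0 * b1[(g * N + n) * O + o];
+                }
+                c[(g * M + m) * O + o] = acc1;
+            }
+}
+#endif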
+template +using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = + std:: + tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, 
PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer> + // clang-format on + >; + +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + // clang-format off + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances{}); + // clang-format on +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp new file mode 100644 index 0000000000..3a38ab0d68 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3; + +// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n] +// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o] +// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k" +// dimensions instead! +template +using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances = + std:: + tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| 
Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer> + // clang-format on + >; + +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance( + std::vector>>& instances) +{ + // clang-format off + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances{}); + // clang-format on +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..9155b14da8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3; + +// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n] +// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o] +// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k" +// dimensions instead! 
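+// Illustrative only: how a client typically reaches the instances below through
+// the factory declared in batched_gemm_gemm.hpp. The DeviceBatchedGemmGemm
+// template-argument order is sketched from the factory's layout/type checks, and
+// the MakeArgumentPointer parameter list is deliberately elided; consult the
+// DeviceBatchedGemmGemm declaration for the authoritative signatures.
+#if 0
+using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmGemm<
+    Row, Col, Row, Row, F16, F16, F16, F16,
+    PassThrough, PassThrough, PassThrough, PassThrough, PassThrough>;
+
+const auto op_ptrs = ck::tensor_operation::device::instance::
+    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
+
+for(const auto& op : op_ptrs)
+{
+    auto arg = op->MakeArgumentPointer(/* pointers, sizes, strides, element ops */);
+    if(op->IsSupportedArgument(arg.get()))
+    {
+        auto invoker = op->MakeInvokerPointer();
+        invoker->Run(arg.get(), StreamConfig{});
+    }
+}
+#endif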
+template +using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std:: + tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, 
PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer> + // clang-format on + >; + +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + // clang-format off + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); + // clang-format on +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp new file mode 100644 index 0000000000..bd7bcd1be6 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3; + +// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n] +// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o] +// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k" +// dimensions instead! +template +using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = + std:: + tuple< + // clang-format off + //################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer| + //################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | 
| | | | | | | | | | | | | | | | | | | | | | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>, + DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 
1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer> + // clang-format on + >; + +void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row, Col, Col, Row, F16, F16, F16, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough>>>& instances) +{ + // clang-format off + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{}); + add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{}); + // clang-format on +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp index b585b7d56a..8089f9efc7 100644 --- a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -220,9 +220,10 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, } std::string best_op_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int num_supported_instances = 0; // profile device op instances for(auto& op_ptr : op_ptrs) @@ -255,6 +256,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + num_supported_instances++; std::string op_name = op_ptr->GetTypeString(); float ave_time = @@ -309,6 +311,8 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, } } + printf("\033[36mFound %d supported instances\033[0m\n", num_supported_instances); + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 4700a34e9d..07beff945b 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -41,7 +41,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) @@ -98,6 +97,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) endif() if(DL_KERNELS) @@ -155,7 +155,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES
"gfx9") list(APPEND DEVICE_INSTANCES device_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) @@ -219,6 +218,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) + list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) endif() if(DL_KERNELS) diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 2b3288ef9d..70d7420992 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,6 +1,14 @@ -add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp) +add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp) if(result EQUAL 0) - add_custom_target(test_batched_gemm_gemm) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) + target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance) +endif() + +add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance) +endif() + +add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance) endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..41c375a624 --- /dev/null +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "test_batched_gemm_gemm_util.hpp" + +template <typename Tuple> +class TestBatchedGemmGemmBF16 : public TestBatchedGemmGemm<Tuple> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, + std::tuple + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmGemmBF16, KernelTypes); + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16) +{ + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadM) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {136, 128, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadN) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 136, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadK) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 40, 128, 1}, + {128, 128, 136, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadO) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 32, 136, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddM) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {129, 128, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddN) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 129, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddK) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 33, 128, 1}, + {128, 128, 129, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +// If the kernel's B1Layout is RowMajor, odd O sizes are expected to be unsupported +TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddO) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 32, 129, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmBF16, DISABLED_Bench_BF16) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {256, 256, 64, 64, 768}, + {256, 256, 128, 128, 768}, + {512, 512, 64, 64, 768}, + {512, 512, 128, 128, 768}, + {1024, 1024, 64, 64, 768}, + {1024, 1024, 128, 128, 768}, + {2048, 2048, 64, 64, 768}, + {2048, 2048, 128, 128, 768}, + {4096, 4096, 64, 64, 768}, + {4096, 4096, 128, 128, 768}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +}
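Each lengths_ row above is consumed by the shared fixture as {M, N, K, O, BatchCount}. The Pad and Odd cases exercise the MNKOPadding specialization, which conceptually rounds each problem dimension up to the block tile before the kernel runs. A toy illustration of that rounding (the tile sizes here are placeholders, not the kernels' actual tiles):

#include <cstdio>

// Round x up to the next multiple of tile.
static int pad_up(int x, int tile) { return ((x + tile - 1) / tile) * tile; }

int main()
{
    std::printf("M = 136 with a 128 tile pads to %d\n", pad_up(136, 128)); // 256
    std::printf("K = 33  with a 32 tile pads to %d\n", pad_up(33, 32));    // 64
    return 0;
}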
diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp new file mode 100644 index 0000000000..4dd55429b4 --- /dev/null +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_gemm_util.hpp" + +template <typename Tuple> +class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, + std::tuple + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) +{ + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {136, 128, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 136, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 40, 128, 1}, + {128, 128, 136, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 32, 136, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {129, 128, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 129, 32, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 33, 128, 1}, + {128, 128, 129, 128, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +// If the kernel's B1Layout is RowMajor, odd O sizes are expected to be unsupported +TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {128, 128, 32, 129, 1}, + }; + this->bench_ = true; + this->verify_ = true; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector<std::vector<int>>{ + {256, 256, 64, 64, 768}, + {256, 256, 128, 128, 768}, + {512, 512, 64, 64, 768}, + {512, 512, 128, 128, 768}, + {1024, 1024, 64, 64, 768}, + {1024, 1024, 128, 128, 768}, + {2048, 2048, 64, 64, 768}, + {2048, 2048, 128, 128, 768}, + {4096, 4096, 64, 64, 768}, + {4096, 4096, 128, 128, 768}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +}
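A plausible reading of the odd-O caveat repeated in both test files: with a row-major B1 (and C), O is the contiguous dimension and is read and written several elements at a time (ScalarPerVector is 8 in the instances above), so an O that is not a multiple of the vector width cannot be fixed up by padding alone. The diff below adds a DeviceInstanceWrapper whose IsSupported method lets a test skip such shapes up front; a sketch of that kind of gate (the GemmSpec argument and skip message are illustrative, not taken from this PR):

// Hypothetical use inside a typed test body; lengths_ rows are {M, N, K, O, Batch}.
DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding> wrapper;
for(const auto& l : this->lengths_)
{
    if(!wrapper.IsSupported(l[0], l[1], l[2], l[3]))
    {
        GTEST_SKIP() << "shape unsupported by this device instance";
    }
}
this->Run();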
diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp index 1a8d5c2e55..b9a41a09c8 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp @@ -1,8 +1,126 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_gemm_util.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" + +template <GemmSpecialization GemmSpec> +struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ALayout = Row; + using B0Layout = Col; + using B1Layout = Row; + using CLayout = Row; + + using ADataType = F16; + using B0DataType = F16; + using B1DataType = F16; + using AccDataType = float; + using CShuffleDataType = float; + using CDataType = F16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + template <ck::index_t... Is> + using S = ck::Sequence<Is...>; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + + bool IsSupported(int M, int N, int K, int O) + { + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast<const ADataType*>(nullptr), + static_cast<const B0DataType*>(nullptr), + static_cast<const B1DataType*>(nullptr), + static_cast<CDataType*>(nullptr), + M, + N, + K, + O, + 0, // BatchCount + 0, // StrideA + 0, // StrideB0 + 0, // StrideB1 + 0, // StrideC + 0, // BatchStrideA + 0, // BatchStrideB0 + 0, // BatchStrideB1 + 0, // BatchStrideC + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + PassThrough{}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; template <typename Tuple> class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple> diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp index b0fffc466e..b538f9fb0a 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp @@ -1,11 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include #include + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp" +#include "profiler/profile_batched_gemm_gemm_impl.hpp" using ck::tensor_operation::device::GemmSpecialization; @@ -13,7 +12,8 @@ template <ck::index_t N> using I = ck::Number<N>; -using F16 = ck::half_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -70,120 +70,3 @@ struct TestBatchedGemmGemm : public ::testing::Test } } }; - -template <GemmSpecialization GemmSpec> -struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128 -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - using ALayout = Row; - using B0Layout = Col; - using B1Layout = Row; - using CLayout = Row; - - using ADataType = F16; - using B0DataType = F16; - using B1DataType = F16; - using AccDataType = float; - using CShuffleDataType = float; - using CDataType = F16; - - using AElementOp = PassThrough; - using B0ElementOp = PassThrough; - using Acc0ElementOp = PassThrough; - using B1ElementOp = PassThrough; - using CElementOp = PassThrough; - - template <ck::index_t... Is> - using S = ck::Sequence<Is...>; - - // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; - - using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle< - ALayout, - B0Layout, - B1Layout, - CLayout, - ADataType, - B0DataType, - B1DataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - B0ElementOp, - Acc0ElementOp, - B1ElementOp, - CElementOp, - GemmSpec, - 1, - 256, - 128, // MPerBlock - 128, // NPerBlock - 32, // KPerBlock - 128, // Gemm1NPerBlock - 32, // Gemm1KPerBlock - 8, // AK1 - 8, // BK1 - 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave - S<4, 64, 1>, // ABlockTransfer - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, // BBlockTransfer - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<8, 32, 1>, // B1BlockTransfer - S<0, 2, 1>, - S<0, 2, 1>, - 1, - 4, - 2, - false, - 1, // CShuffleMXdlPerWavePerShuffle - 2, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock - - bool IsSupported(int M, int N, int K, int O) - { - auto gemm = DeviceGemmGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast<const ADataType*>(nullptr), - static_cast<const B0DataType*>(nullptr), - static_cast<const B1DataType*>(nullptr), - static_cast<CDataType*>(nullptr), - M, - N, - K, - O, - 0, // BatchCount - 0, // StrideA - 0, // StrideB0 - 0, // StrideB1 - 0, // StrideC - 0, // BatchStrideA - 0, // BatchStrideB0 - 0, // BatchStrideB1 - 0, // BatchStrideC - PassThrough{}, // a_element_op - PassThrough{}, // b0_element_op - PassThrough{}, // acc0_element_op - PassThrough{}, // b1_element_op - PassThrough{}); // c_element_op - - return gemm.IsSupportedArgument(argument); - } -};
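With the new instance files registered, the profiler's num_supported_instances count can be reproduced in user code. A minimal sketch, assuming the usual CK instance-factory headers are included, that Row, Col, F16 and PassThrough are the aliases used throughout the instance files in this PR, and that the DeviceBatchedGemmGemm interface parameters mirror the existing xdl instance files (the MakeArgumentPointer call is elided because its full argument list depends on the interface):

#include <memory>
#include <vector>

using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmGemm<
    Row, Col, Row, Row, F16, F16, F16, F16,
    PassThrough, PassThrough, PassThrough, PassThrough, PassThrough>;

int count_supported_wmma_v3_f16_instances()
{
    std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
    ck::tensor_operation::device::instance::
        add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(op_ptrs);

    int num_supported = 0;
    for(const auto& op_ptr : op_ptrs)
    {
        (void)op_ptr; // build an argument via op_ptr->MakeArgumentPointer(...) for the target
                      // shape, then count op_ptr->IsSupportedArgument(argument.get()) hits,
                      // which is the same check the profiler tallies
    }
    return num_supported;
}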