Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-04 21:11:23 +00:00
parent 5677205f88
commit 7f65be1b3e
51 changed files with 3709 additions and 189 deletions

View File

@@ -1,6 +1,12 @@
# One executable per variant: target example_batched_gemm_gemm_<variant> is built from
# batched_gemm_gemm_<variant>.cpp. The list order matches the original declaration order.
foreach(variant
        xdl_fp32
        xdl_fp16
        xdl_bf16
        wmma_cshuffle_v3_bf16
        wmma_cshuffle_v3_fp8
        wmma_cshuffle_v3_fp16
        wmma_cshuffle_v3_int8)
    add_example_executable(example_batched_gemm_gemm_${variant} batched_gemm_gemm_${variant}.cpp)
endforeach()

# The int4 variant is only declared when USE_BITINT_EXTENSION_INT4 is set.
if(USE_BITINT_EXTENSION_INT4)
    add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif()

View File

@@ -0,0 +1,276 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                |------------------|
                                                       Gemm0
                                                |-----------------------------|
                                                             Gemm1
*/
// Layout tags used in the DeviceMHAFactory instance lists.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Shorthand for a compile-time sequence of ck::index_t values.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;

// Settings shared by every device instance declared in this file.
static constexpr auto GemmSpec    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto PipeSched   = ck::BlockGemmPipelineScheduler::Interwave;
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
// clang-format off
// Build-time switches that select which device instances end up in DeviceMHAFactory.
//  - CK_MHA_USE_RCCR_LAYOUT: when defined, the factory holds a single instance whose layout
//    tags are Row, Col, Col, Row. When undefined, every instance uses Row, Col, Row, Row.
//  - CK_MHA_USE_WAVE_{1,2,4,8}: each contributes two instances to the non-RCCR factory.
//    The groups are NOT comma-separated from one another, so exactly one of them may be
//    defined at a time: enabling two leaves adjacent template arguments without a separator,
//    and enabling none leaves the std::tuple<> empty.
// NOTE(review): the first integer of each instance (32/64/128/256) looks like the thread-block
// size, i.e. 32 x wave count -- confirm against DeviceBatchedGemmGemm_Wmma_CShuffleV3.
// #define CK_MHA_USE_RCCR_LAYOUT
#define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
// #define CK_MHA_USE_WAVE_4
// #define CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_RCCR_LAYOUT
// RCCR variant: one instance; the CK_MHA_USE_WAVE_* switches are ignored in this branch.
using DeviceMHAFactory =
std::tuple<
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Col, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>
>;
#else
// Default variant: the tuple is assembled from whichever CK_MHA_USE_WAVE_* group is enabled.
// Within a group, the two instances differ only in the second "Gemm 0" value (128 vs 64) and
// the middle "Per repeat" value (8 vs 4) -- except for the WAVE_8 group, see below.
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_2
// 2-wave group (leading value 64).
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
64,
// Gemm 0
32, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
64,
// Gemm 0
32, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_4
// 4-wave group (leading value 128).
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
128,
// Gemm 0
64, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
128,
// Gemm 0
64, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_8
// 8-wave group (leading value 256).
// NOTE(review): the two instances in this group are identical, unlike the other groups where
// the second instance uses 64 / 4 instead of 128 / 8. Looks like a copy-paste leftover --
// confirm whether the second instance was meant to differ.
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
256,
// Gemm 0
128, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
256,
// Gemm 0
128, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
>;
#endif
// clang-format on
// Host reference for Gemm0: multiplies an ADataType tensor by a B0DataType tensor and writes an
// AccDataType tensor (third argument), accumulating in AccDataType (fourth argument).
// Acc0ElementOp is the element-wise op supplied in the output slot.
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Host reference for Gemm1: its left operand is also typed ADataType, the right operand is
// B1DataType and the result is CDataType, accumulating in AccDataType.
// NOTE(review): the left-operand element op is declared as AElementOp; callers must pass an
// object of that type when building the argument.
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"
// Entry point. Forwards to run() when the device reports gfx11 or gfx12 support; otherwise
// prints a warning and reports success so the example is treated as skipped.
int main(int argc, char* argv[])
{
    if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
    {
        return run(argc, argv);
    }

    std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
              << std::endl;
    return 0;
}

View File

@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
// Type configuration for the bf16 variant. These aliases must be declared before the shared
// .inc file is included, because that file refers to them by name.
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element types of the A, B0 and B1 inputs.
using ADataType = BF16;
using B0DataType = BF16;
using B1DataType = BF16;
// Accumulation and C-shuffle types are float; the final output C is bf16 again.
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = BF16;
// Element-wise operators: everything is a pass-through except the Gemm0 output, which is
// handled by Scale.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"

View File

@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
// Type configuration for the fp16 variant. These aliases must be declared before the shared
// .inc file is included, because that file refers to them by name.
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element types of the A, B0 and B1 inputs.
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
// Accumulation and C-shuffle types are float; the final output C is fp16 again.
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
// Element-wise operators: everything is a pass-through except the Gemm0 output, which is
// handled by Scale.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"

View File

@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
// Type configuration for the fp8 variant. These aliases must be declared before the shared
// .inc file is included, because that file refers to them by name.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element types of the A, B0 and B1 inputs.
using ADataType = ck::f8_t;
using B0DataType = ck::f8_t;
using B1DataType = ck::f8_t;
// Accumulation and C-shuffle types are float; the final output C is ck::f8_t again.
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::f8_t;
// Element-wise operators: everything is a pass-through except the Gemm0 output, which is
// handled by Scale.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"

View File

@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
// Type configuration for the int8 variant. These aliases must be declared before the shared
// .inc file is included, because that file refers to them by name.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Element types of the A, B0 and B1 inputs.
using ADataType = int8_t;
using B0DataType = int8_t;
using B1DataType = int8_t;
// Accumulation and C-shuffle types are int32_t; the final output C is int8_t again.
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = int8_t;
// Element-wise operators: everything is a pass-through except the Gemm0 output, which is
// handled by Scale.
// NOTE(review): Scale is constructed from a float alpha while the accumulator here is
// int32_t -- confirm Scale provides the matching overload.
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"

View File

@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/// Example driver for the fused batched GEMM + GEMM kernels listed in DeviceMHAFactory.
///
/// Builds host tensors for A, B0, B1 and C, uploads the inputs, then runs every factory
/// instance that accepts the problem. Each instance is reported (time, TFlops, GB/s) and, when
/// verification is enabled, compared against a CPU reference.
///
/// Command line (argc must be 1, 4 or 10; anything else prints usage and exits with status 0):
///   argv[1]     verification (0 = no, 1 = yes)
///   argv[2]     initialization method (see the switch below)
///   argv[3]     time kernel (0 = no, 1 = yes)
///   argv[4..8]  M, N, K, O, G
///   argv[9]     alpha, the value handed to Acc0ElementOp
///
/// @return 0 when no verified instance mismatched the reference, 1 otherwise.
int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 113;
#ifdef CK_MHA_USE_RCCR_LAYOUT
    ck::index_t N = 480; // Must be multiple of 8 even with padding.
#else
    ck::index_t N = 477;
#endif
    ck::index_t K = 200; // Must be multiple of 8 even with padding.
    ck::index_t O = 208; // Must be multiple of 8 even with padding.
    ck::index_t G = 91;  // Batch
    float alpha   = 1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 10)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);
        G = std::stoi(argv[8]);

        alpha = std::stof(argv[9]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, G\n");
        printf("arg9: scale (alpha)\n");
        exit(0);
    }

    // Host tensor descriptors. B0 is stored as [G, N, K]; B1 is described with lengths
    // [G, O, N] and strides that depend on the selected layout.
    std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
    std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
    std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
    std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
    std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
#ifdef CK_MHA_USE_RCCR_LAYOUT
    std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
#else
    std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
#endif
    std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
    std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]

    Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
    Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
    Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
    Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
    Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
    std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;

    // Input initialization. Methods 4-10 mix unit and random tensors, which helps isolate
    // which operand a mismatch comes from.
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 2:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    case 4: // A, B0, B1 1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 5: // Rand: b1 b0; unit: a
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 6: // Rand: a b0 ; unit: B1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 7: // Rand: a b1 ; unit: b0
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 8: // Rand: a ; unit: b0 b1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 9: // Rand: b0 ; unit: a b1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 10: // Rand: b1 ; unit: a b0
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    default:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_device_buf.ToDevice(b0_g_n_k.mData.data());
    b1_device_buf.ToDevice(b1_g_o_n.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    float best_perf         = .0;
    float best_time         = .0;
    int not_pass            = 0;
    std::string best_kernel = "";

    // The CPU reference depends only on the inputs and element-wise ops, never on the device
    // instance, so it is computed lazily on first use and then shared by every instance.
    bool host_reference_ready = false;

    printf("Verification: %s\n", do_verification ? "ON" : "OFF");

    ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});

        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;

        auto gemm        = DeviceMHAInstance{};
        auto invoker_ptr = gemm.MakeInvokerPointer();
        auto argument_ptr =
            gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                     static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
                                     static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
                                     static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                     M,
                                     N,
                                     K,
                                     O,
                                     G,                   // Batch,
                                     a_g_m_k_strides[1],  // StrideA,
                                     b0_g_n_k_strides[1], // StrideB0,
#ifdef CK_MHA_USE_RCCR_LAYOUT
                                     b1_g_o_n_strides[1], // StrideB1,
#else
                                     b1_g_o_n_strides[2], // StrideB1,
#endif
                                     c_g_m_o_strides[1],  // StrideC,
                                     a_g_m_k_strides[0],  // BatchStrideA
                                     b0_g_n_k_strides[0], // BatchStrideB0
                                     b1_g_o_n_strides[0], // BatchStrideB1
                                     c_g_m_o_strides[0],  // BatchStrideC
                                     a_element_op,
                                     b0_element_op,
                                     acc0_element_op,
                                     b1_element_op,
                                     c_element_op);

        // Instances that reject the problem are reported and skipped; they neither count as a
        // verification failure nor compete for "best kernel".
        if(!gemm.IsSupportedArgument(argument_ptr.get()))
        {
            std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
            return;
        }

        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        std::size_t flop      = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                                G;

        // The report labels ave_time as milliseconds, so flop / 1e9 / ms yields TFlops and
        // bytes / 1e6 / ms yields GB/s.
        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;

        if(tflops > best_perf)
        {
            best_perf   = tflops;
            best_time   = ave_time * 1000;
            best_kernel = gemm.GetTypeString();
        }

        if(do_verification)
        {
            c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());

            if(!host_reference_ready)
            {
                Tensor<B0DataType> b0_g_k_n({G, K, N});
                Tensor<B1DataType> b1_g_n_o({G, N, O});
                Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
                Tensor<ADataType> a1_g_m_n({G, M, N});     // scratch object after conversion

                // permute
                b0_g_n_k.ForEach(
                    [&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
                b1_g_o_n.ForEach(
                    [&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });

                // gemm 0
                auto ref_gemm0          = ReferenceGemm0Instance{};
                auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
                auto ref_gemm0_argument = ref_gemm0.MakeArgument(
                    a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

                ref_gemm0_invoker.Run(ref_gemm0_argument);

                acc0_g_m_n.ForEach([&](auto& self, auto idx) {
                    // Passthrough instead of softmax, DOES involve data type conversion.
                    a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
                });

                // gemm1
                auto ref_gemm1          = ReferenceGemm1Instance{};
                auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
                auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
                                                                 b1_g_n_o,
                                                                 c_g_m_o_host_result,
                                                                 PassThrough{},
                                                                 b1_element_op,
                                                                 c_element_op);

                ref_gemm1_invoker.Run(ref_gemm1_argument);

                host_reference_ready = true;
            }

            // default absolute error and relative error is 0.001
            double rtol = 1e-3;
            double atol = 1e-3;

            // when BF16 is taken, set absolute error and relative error to 0.01
            if constexpr(std::is_same_v<ADataType, ck::bhalf_t> &&
                         std::is_same_v<B0DataType, ck::bhalf_t> &&
                         std::is_same_v<B1DataType, ck::bhalf_t> &&
                         std::is_same_v<CDataType, ck::bhalf_t>)
            {
                rtol = 1e-2;
                atol = 1e-2;
            }

            bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
                                                              c_g_m_o_host_result.mData,
                                                              "Error: Incorrect results!",
                                                              rtol,
                                                              atol);
            printf("Verification: %s, Pass: %s\n",
                   do_verification ? "ON" : "OFF",
                   this_run_verification ? "YES" : "NO");

            if(!this_run_verification)
            {
                not_pass = 1;
                printf("%d th MHA instance verification Failed \n", i.value);
            }
        }
    });

    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
              << ", O: " << O << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
              << " us" << std::endl;
    std::cout << "---------------------------------------------------------------------------------"
                 "-----------"
              << std::endl;
    return not_pass;
}

View File

@@ -1,7 +1,6 @@
# Prepend (BEFORE) the project's header roots to this directory's include path.
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/library/include
${PROJECT_SOURCE_DIR}/example/include
)
# Custom target named `examples`; it carries no commands of its own here.
# NOTE(review): presumably individual example targets are attached to it elsewhere -- confirm.
add_custom_target(examples)

View File

@@ -5,7 +5,7 @@
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstring>

View File

@@ -7,7 +7,7 @@
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstring>

View File

@@ -1,6 +1,6 @@
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <algorithm>
#include <cstring>

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2

View File

@@ -3,7 +3,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
template <typename T>

View File

@@ -13,7 +13,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#if 0
template <typename T>

View File

@@ -1,7 +1,7 @@
#include "ck_tile/host.hpp"
#include "rmsnorm2d_fwd.hpp"
#include <cstring>
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
// different threshold for different dtype
template <typename DataType>

View File

@@ -1,7 +1,7 @@
#include "ck_tile/host.hpp"
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <cstring>
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
// different threshold for different dtype
template <typename InputDataType>

View File

@@ -1,6 +1,6 @@
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
// different threshold for different dtype

View File

@@ -14,7 +14,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
auto create_args(int argc, char* argv[])
{

View File

@@ -1,6 +1,6 @@
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
#include <set>

View File

@@ -5,7 +5,7 @@
#include <set>
#include "ck_tile/host.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "fused_moe.hpp"
// different threshold for different dtype

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <json_dump.hpp>
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2

View File

@@ -9,7 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2

View File

@@ -2,7 +2,7 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename T>
constexpr const char* DataTypeToString()
{

View File

@@ -3,7 +3,7 @@
#pragma once
#include <cstddef>
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
template <typename ADataType,
typename BDataType,

View File

@@ -4,7 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "elementwise_common.hpp"
auto create_args(int argc, char* argv[])

View File

@@ -4,7 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "elementwise_common.hpp"
auto create_args(int argc, char* argv[])

View File

@@ -4,7 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "elementwise_common.hpp"
auto create_args(int argc, char* argv[])

View File

@@ -4,7 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "elementwise_common.hpp"
auto create_args(int argc, char* argv[])

View File

@@ -12,7 +12,7 @@
#include "batched_transpose_example.hpp"
#include "json_dump.hpp"
#include "ck_tile/utility/json_dump.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)

View File

@@ -1,700 +0,0 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
#include "rapidjson/writer.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/document.h"
#include "rapidjson/rapidjson.h"
// #include <fstream>
#pragma GCC diagnostic pop
// Statement macros for writing one JSON object to a file. They are meant to be used together
// inside a single function scope, because they communicate through the local names they
// declare (`file_str`, `file`, `s`, `writer`).
// NOTE(review): the macros use std::ofstream, std::string, std::runtime_error and std::cout,
// but this header does not include <fstream> (the include above is commented out), so the
// including file has to provide those headers.

// Opens `file_name` for writing (throws std::runtime_error if that fails), creates the
// rapidjson buffer/writer pair and starts the top-level JSON object.
#define START_JSON_DUMP_FILE(file_name) \
std::string file_str(file_name); \
std::ofstream file(file_str); \
if(!file.is_open()) \
{ \
throw std::runtime_error("Could not open file: " + std::string(file_name)); \
} \
rapidjson::StringBuffer s; \
rapidjson::Writer<rapidjson::StringBuffer> writer(s); \
writer.StartObject();
// Closes the top-level object, writes the buffered JSON text to the file, closes it and
// prints a confirmation to std::cout. Requires a preceding START_JSON_DUMP_FILE in scope.
#define END_JSON_DUMP_FILE() \
writer.EndObject(); \
file << s.GetString(); \
file.close(); \
std::cout << "Results written to " << file_str << " successfully" << std::endl;
// Writes one key/value member through the in-scope `writer`. The expansion already ends in a
// semicolon, so a call site written as `ADD_KEY_VALUE(k, v);` yields an extra empty statement.
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
// Writes the "perf" member through the in-scope `writer`; also ends in a semicolon.
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
/// Writes `key` and then `value` to `writer` as one JSON object member.
///
/// The value is emitted according to T:
///   - const char* and std::string are written as JSON strings,
///   - floating-point types are written as doubles,
///   - integral types are written as 64-bit signed integers.
/// Any other T is rejected at compile time.
template <typename T>
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
                        const char* key,
                        T value)
{
    constexpr bool is_c_string   = std::is_same_v<T, const char*>;
    constexpr bool is_std_string = std::is_same_v<T, std::string>;
    constexpr bool is_real       = std::is_floating_point_v<T>;
    constexpr bool is_integer    = std::is_integral_v<T>;

    writer.Key(key);

    if constexpr(is_c_string)
    {
        const auto length = static_cast<rapidjson::SizeType>(std::strlen(value));
        writer.String(value, length);
    }
    else if constexpr(is_std_string)
    {
        const auto length = static_cast<rapidjson::SizeType>(value.length());
        writer.String(value.c_str(), length);
    }
    else if constexpr(is_real)
    {
        writer.Double(static_cast<double>(value));
    }
    else if constexpr(is_integer)
    {
        writer.Int64(static_cast<int64_t>(value));
    }
    else
    {
        // Only reached for unsupported T, where this condition is necessarily false.
        static_assert(is_c_string || is_real || is_integer,
                      "Unsupported type for JSON serialization");
    }
}
/// Appends a "perf" entry to `writer`: the string "perf" followed by a one-element array
/// holding an object with the members "time", "tflops" and "gbytes".
static void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
                             float time,
                             float tflops,
                             float gbytes)
{
    // sizeof includes the terminating NUL, hence the "- 1".
    static constexpr char section_name[] = "perf";
    writer.String(section_name, static_cast<rapidjson::SizeType>(sizeof(section_name) - 1));

    writer.StartArray();
    {
        writer.StartObject();
        add_key_value_pair(writer, "time", time);
        add_key_value_pair(writer, "tflops", tflops);
        add_key_value_pair(writer, "gbytes", gbytes);
        writer.EndObject();
    }
    writer.EndArray();
}
// Detection trait: `has_warp_tile_members<Config>::value` is true only when all three of
// `Config::M_Warp_Tile`, `Config::N_Warp_Tile` and `Config::K_Warp_Tile` are well-formed
// expressions; otherwise the primary template (false) is selected.
template <typename Config, typename Enable = void>
struct has_warp_tile_members : std::false_type
{
};

// Chosen through SFINAE when every decltype below is valid.
template <typename Config>
struct has_warp_tile_members<Config,
                             std::void_t<decltype(Config::M_Warp_Tile),
                                         decltype(Config::N_Warp_Tile),
                                         decltype(Config::K_Warp_Tile)>> : std::true_type
{
};
// Writes the configuration and measured performance of one GEMM run to `json_filename`
// as a single JSON object.
//
// Layout names are read from `ALayout/BLayout/CLayout::name`, data-type names from
// `DTypeTraits<...>::name`. The "warp_tile" member is emitted only when GemmConfig exposes
// M_Warp_Tile, N_Warp_Tile and K_Warp_Tile.
// Throws std::runtime_error (via START_JSON_DUMP_FILE) when the file cannot be opened.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GemmConfig,
          template <typename>
          typename DTypeTraits>
void dump_gemm_json_results(const std::string& json_filename,
                            int M,
                            int N,
                            int K,
                            int stride_A,
                            int stride_B,
                            int stride_C,
                            bool persistent,
                            bool pass,
                            float ave_time,
                            float tflops,
                            float gb_per_sec,
                            const std::string& kernel_name = "gemm_basic")
{
    // Flags are stored as short strings rather than as numbers.
    const char* const sparsity_text     = GemmConfig::UseStructuredSparsity ? "on" : "off";
    const char* const persistent_text   = persistent ? "on" : "off";
    const char* const verification_text = pass ? "pass" : "fail";

    START_JSON_DUMP_FILE(json_filename);

    ADD_KEY_VALUE("name", kernel_name);

    // Problem shape and strides.
    ADD_KEY_VALUE("M", M);
    ADD_KEY_VALUE("N", N);
    ADD_KEY_VALUE("K", K);
    ADD_KEY_VALUE("stride_A", stride_A);
    ADD_KEY_VALUE("stride_B", stride_B);
    ADD_KEY_VALUE("stride_C", stride_C);

    // Layout and data-type names.
    ADD_KEY_VALUE("A_layout", ALayout::name);
    ADD_KEY_VALUE("B_layout", BLayout::name);
    ADD_KEY_VALUE("C_layout", CLayout::name);
    ADD_KEY_VALUE("A_type", DTypeTraits<ADataType>::name);
    ADD_KEY_VALUE("B_type", DTypeTraits<BDataType>::name);
    ADD_KEY_VALUE("C_type", DTypeTraits<CDataType>::name);

    ADD_KEY_VALUE("structured_sparsity", sparsity_text);
    if constexpr(has_warp_tile_members<GemmConfig>::value)
    {
        const std::string warp_tile_text = std::to_string(GemmConfig::M_Warp_Tile) + "x" +
                                           std::to_string(GemmConfig::N_Warp_Tile) + "x" +
                                           std::to_string(GemmConfig::K_Warp_Tile);
        ADD_KEY_VALUE("warp_tile", warp_tile_text);
    }
    ADD_KEY_VALUE("persistent", persistent_text);
    ADD_KEY_VALUE("verification", verification_text);

    ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
    END_JSON_DUMP_FILE();
}
void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("stride_A", stride_A);
ADD_KEY_VALUE("stride_B", stride_B);
ADD_KEY_VALUE("stride_C", stride_C);
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
ADD_KEY_VALUE("batch_count", batch_count);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename ALayout, typename BLayout, typename CLayout>
void dump_grouped_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int group_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "grouped_gemm")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("group_count", group_count);
ADD_KEY_VALUE("A_layout", ALayout::name);
ADD_KEY_VALUE("B_layout", BLayout::name);
ADD_KEY_VALUE("C_layout", CLayout::name);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("DataType", datatype);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", stride_A);
ADD_KEY_VALUE("StrideB", stride_B);
ADD_KEY_VALUE("StrideC", stride_C);
ADD_KEY_VALUE("kbatch", kbatch);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", StrideA);
ADD_KEY_VALUE("StrideB", StrideB);
ADD_KEY_VALUE("StrideD0", StrideD0);
ADD_KEY_VALUE("StrideD1", StrideD1);
ADD_KEY_VALUE("StrideE", StrideE);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("grid_size", grid_size);
ADD_KEY_VALUE("block_size", block_size);
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("prec_sm", prec_sm);
ADD_KEY_VALUE("prec_sy", prec_sy);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("xr_stride", xr_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("yr_stride", yr_stride);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename DataType, template <typename> typename DTypeTraits>
void dump_reduce_json_results(const std::string& json_filename,
int N,
int C,
int H,
int W,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "reduce")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
using Traits = DTypeTraits<DataType>;
ADD_KEY_VALUE("data_type", Traits::name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("data_type", data_type);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
END_JSON_DUMP_FILE();
}
// Writes the input/weight precision strings, the tokens/experts/topk values, the
// input/output strides, the verification status and the measured performance of one
// topk-softmax run to `json_filename`.
// Throws std::runtime_error (via START_JSON_DUMP_FILE) when the file cannot be opened.
void dump_topk_softmax_json(const std::string& json_filename,
                            const std::string& input_prec,
                            const std::string& weight_prec,
                            int tokens,
                            int experts,
                            int topk,
                            int stride_input,
                            int stride_output,
                            float ave_time,
                            float tflop,
                            float gb_per_sec,
                            bool pass,
                            const std::string& kernel_name = "topk_softmax")
{
    const char* const verification_text = pass ? "pass" : "fail";

    START_JSON_DUMP_FILE(json_filename);

    ADD_KEY_VALUE("name", kernel_name);
    ADD_KEY_VALUE("input_prec", input_prec);
    ADD_KEY_VALUE("weight_prec", weight_prec);

    ADD_KEY_VALUE("tokens", tokens);
    ADD_KEY_VALUE("experts", experts);
    ADD_KEY_VALUE("topk", topk);
    ADD_KEY_VALUE("stride_input", stride_input);
    ADD_KEY_VALUE("stride_output", stride_output);

    ADD_KEY_VALUE("verification", verification_text);
    ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
    END_JSON_DUMP_FILE();
}
// Writes the precision string, the m/n sizes, the x/xr/y/yr strides, the
// `use_model_sensitive_rmsnorm` value, the verification status and the measured
// performance of one rmsnorm2d forward run to `json_filename`.
// Throws std::runtime_error (via START_JSON_DUMP_FILE) when the file cannot be opened.
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
                             const std::string& prec_str,
                             int m,
                             int n,
                             int x_stride,
                             int xr_stride,
                             int y_stride,
                             int yr_stride,
                             int use_model_sensitive_rmsnorm,
                             float ave_time,
                             float tflops,
                             float gb_per_sec,
                             bool pass,
                             const std::string& kernel_name = "rmsnorm2d_fwd")
{
    const char* const verification_text = pass ? "pass" : "fail";

    START_JSON_DUMP_FILE(json_filename);

    ADD_KEY_VALUE("name", kernel_name);
    ADD_KEY_VALUE("prec", prec_str);

    ADD_KEY_VALUE("m", m);
    ADD_KEY_VALUE("n", n);
    ADD_KEY_VALUE("x_stride", x_stride);
    ADD_KEY_VALUE("xr_stride", xr_stride);
    ADD_KEY_VALUE("y_stride", y_stride);
    ADD_KEY_VALUE("yr_stride", yr_stride);

    ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
    ADD_KEY_VALUE("verification", verification_text);
    ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
    END_JSON_DUMP_FILE();
}
// Writes the input and quantized data-type strings, the m/n sizes, the stride, epsilon,
// the verification status and the measured performance of one add+rmsnorm2d+rdquant
// forward run to `json_filename`.
// Throws std::runtime_error (via START_JSON_DUMP_FILE) when the file cannot be opened.
void dump_add_rmsnorm2d_rdquant_fwd_json(
    const std::string& json_filename,
    const std::string& input_data_type,
    const std::string& quantized_data_type,
    int m,
    int n,
    int stride,
    float epsilon,
    float ave_time,
    float tflops,
    float gb_per_sec,
    bool pass,
    const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
{
    const char* const verification_text = pass ? "pass" : "fail";

    START_JSON_DUMP_FILE(json_filename);

    ADD_KEY_VALUE("name", kernel_name);
    ADD_KEY_VALUE("input_data_type", input_data_type);
    ADD_KEY_VALUE("quantized_data_type", quantized_data_type);

    ADD_KEY_VALUE("m", m);
    ADD_KEY_VALUE("n", n);
    ADD_KEY_VALUE("stride", stride);
    ADD_KEY_VALUE("epsilon", epsilon);

    ADD_KEY_VALUE("verification", verification_text);
    ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
    END_JSON_DUMP_FILE();
}
// Writes the precision string, the m/n sizes, the x/y strides, the verification status and
// the measured performance of one smoothquant run to `json_filename`.
// Throws std::runtime_error (via START_JSON_DUMP_FILE) when the file cannot be opened.
void dump_smoothquant_json(const std::string& json_filename,
                           const std::string& prec_str,
                           int m,
                           int n,
                           int x_stride,
                           int y_stride,
                           float ave_time,
                           float tflops,
                           float gb_per_sec,
                           bool pass,
                           const std::string& kernel_name = "smoothquant")
{
    const char* const verification_text = pass ? "pass" : "fail";

    START_JSON_DUMP_FILE(json_filename);

    ADD_KEY_VALUE("name", kernel_name);
    ADD_KEY_VALUE("prec", prec_str);

    ADD_KEY_VALUE("m", m);
    ADD_KEY_VALUE("n", n);
    ADD_KEY_VALUE("x_stride", x_stride);
    ADD_KEY_VALUE("y_stride", y_stride);

    ADD_KEY_VALUE("verification", verification_text);
    ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
    END_JSON_DUMP_FILE();
}
void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("index_prec", index_prec);
ADD_KEY_VALUE("weight_prec", weight_prec);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("num_experts", num_experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("LayoutIn", layout_in);
ADD_KEY_VALUE("LayoutOut", layout_out);
ADD_KEY_VALUE("Precision", prec);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("api", api_str);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("tokens", tokens);
if(is_local_token)
{
ADD_KEY_VALUE("local_tokens", local_tokens);
}
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("intermediate_size", intermediate_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("block_m", block_m);
ADD_KEY_VALUE("activation", activation);
ADD_KEY_VALUE("gate_only", gate_only);
ADD_KEY_VALUE("fused_quant", fused_quant);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
END_JSON_DUMP_FILE();
}
void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
bool squant,
const std::string& bais,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("io_layout", io_layout);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale_s", scale_s);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("lse", lse);
ADD_KEY_VALUE("squant", squant);
ADD_KEY_VALUE("bias", bais);
ADD_KEY_VALUE("vlayout", vlayout);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", data_type);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("i_perm", i_perm);
ADD_KEY_VALUE("o_perm", o_perm);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_q);
ADD_KEY_VALUE("seqlen_k", seqlen_k);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale", scale);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("use_dbias", use_dbias);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("s_randval", s_randval);
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
ADD_KEY_VALUE("mask", mask);
ADD_KEY_VALUE("mask_left", mask_left);
ADD_KEY_VALUE("mask_right", mask_right);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}