mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                |------------------|
                                                       Gemm0
                                                |-----------------------------|
                                                             Gemm1
*/
|
||||
|
||||
static constexpr auto PipeSched = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// clang-format off
|
||||
// #define CK_MHA_USE_RCCR_LAYOUT
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
// #define CK_MHA_USE_WAVE_2
|
||||
// #define CK_MHA_USE_WAVE_4
|
||||
// #define CK_MHA_USE_WAVE_8
|
||||
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Col, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
>;
|
||||
#else
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_2
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_4
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
>;
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
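// Host reference for Gemm0 (A * B0): the result is stored as AccDataType and
// Acc0ElementOp is applied to it.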
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp>;
|
||||
|
||||
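// Host reference for Gemm1 (Gemm0 result converted to ADataType, times B1): the
// result is stored as CDataType and CElementOp is applied to it.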
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using B0DataType = BF16;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = BF16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using B0DataType = ck::f8_t;
|
||||
using B1DataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::f8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using B0DataType = int8_t;
|
||||
using B1DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CDataType = int8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape for A/B0/B1/C
|
||||
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
|
||||
ck::index_t M = 113;
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
ck::index_t N = 480; // Must be multiple of 8 even with padding.
|
||||
#else
|
||||
ck::index_t N = 477;
|
||||
#endif
|
||||
ck::index_t K = 200; // Must be multiple of 8 even with padding.
|
||||
ck::index_t O = 208; // Must be multiple of 8 even with padding.
|
||||
ck::index_t G = 91; // Batch
|
||||
|
||||
float alpha = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
G = std::stoi(argv[8]);
|
||||
|
||||
alpha = std::stof(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 8: M, N, K, O, G\n");
|
||||
printf("arg9: scale (alpha)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
|
||||
std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
|
||||
std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
|
||||
std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
|
||||
std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
|
||||
#else
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
|
||||
#endif
|
||||
std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
|
||||
std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]
|
||||
|
||||
Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
|
||||
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
case 4: // A, B0, B1 1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5: // Rand: b1 b0; unit: a
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 6: // Rand: a b0 ; unit: B1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 7: // Rand: a b1 ; unit: b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 8: // Rand: a ; unit: b0 b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 9: // Rand: b0 ; unit: a b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 10: // Rand: b1 ; unit: a b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_g_n_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_g_o_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
float best_perf = .0;
|
||||
float best_time = .0;
|
||||
int not_pass = 0;
|
||||
std::string best_kernel = "";
|
||||
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
|
||||
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
|
||||
|
||||
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
|
||||
auto gemm = DeviceMHAInstance{};
|
||||
auto invoker_ptr = gemm.MakeInvokerPointer();
|
||||
auto argument_ptr =
|
||||
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
G, // Batch,
|
||||
a_g_m_k_strides[1], // StrideA,
|
||||
b0_g_n_k_strides[1], // StrideB0,
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
b1_g_o_n_strides[1], // StrideB1,
|
||||
#else
|
||||
b1_g_o_n_strides[2], // StrideB1,
|
||||
#endif
|
||||
c_g_m_o_strides[1], // StrideC,
|
||||
a_g_m_k_strides[0], // BatchStrideA
|
||||
b0_g_n_k_strides[0], // BatchStrideB0
|
||||
b1_g_o_n_strides[0], // BatchStrideB1
|
||||
c_g_m_o_strides[0], // BatchStrideC
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
|
||||
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
|
||||
G;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
if(tflops > best_perf)
|
||||
{
|
||||
best_perf = tflops;
|
||||
best_time = ave_time * 1000;
|
||||
best_kernel = gemm.GetTypeString();
|
||||
}
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
|
||||
Tensor<B0DataType> b0_g_k_n({G, K, N});
|
||||
Tensor<B1DataType> b1_g_n_o({G, N, O});
|
||||
Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
|
||||
Tensor<ADataType> a1_g_m_n({G, M, N}); // scratch object after conversion
|
||||
|
||||
// permute
|
||||
b0_g_n_k.ForEach(
|
||||
[&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
b1_g_o_n.ForEach(
|
||||
[&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
|
||||
// gemm 0
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
// Passthrough instead of softmax, DOES involve data type conversion.
|
||||
a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
|
||||
});
|
||||
|
||||
// gemm1
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
|
||||
b1_g_n_o,
|
||||
c_g_m_o_host_result,
|
||||
PassThrough{},
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// default absolute error and relative error is 0.001
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
|
||||
// when BF16 is taken, set absolute error and relative error to 0.01
|
||||
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
|
||||
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
rtol = 1e-2;
|
||||
atol = 1e-2;
|
||||
}
|
||||
|
||||
bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
|
||||
c_g_m_o_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
printf("Verification: %s, Pass: %s\n",
|
||||
do_verification ? "ON" : "OFF",
|
||||
this_run_verification ? "YES" : "NO");
|
||||
|
||||
if(!this_run_verification)
|
||||
{
|
||||
not_pass = 1;
|
||||
printf("%d th MHA instance verification Failed \n", i.value);
|
||||
}
|
||||
}
|
||||
});
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
|
||||
<< ", O: " << O << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
|
||||
<< " us" << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
return not_pass;
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${PROJECT_SOURCE_DIR}/example/include
|
||||
)
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename InputDataType>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include <json_dump.hpp>
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include <type_traits>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
#include "batched_transpose_example.hpp"
|
||||
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
|
||||
@@ -1,700 +0,0 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
|
||||
#include "rapidjson/writer.h"
|
||||
#include "rapidjson/stringbuffer.h"
|
||||
#include "rapidjson/document.h"
|
||||
#include "rapidjson/rapidjson.h"
|
||||
// #include <fstream>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
// Opens `file_name` for writing and starts a JSON object.
// Declares the locals `file_str`, `file`, `s` and `writer` in the enclosing scope, so it can be
// used at most once per scope. Throws std::runtime_error when the file cannot be opened.
#define START_JSON_DUMP_FILE(file_name)                               \
    std::string file_str(file_name);                                  \
    std::ofstream file(file_str);                                     \
    if(!file.is_open())                                               \
    {                                                                 \
        throw std::runtime_error("Could not open file: " + file_str); \
    }                                                                 \
    rapidjson::StringBuffer s;                                        \
    rapidjson::Writer<rapidjson::StringBuffer> writer(s);             \
    writer.StartObject();

// Closes the JSON object started by START_JSON_DUMP_FILE, writes the buffered text to the file,
// closes the file and reports the output path on stdout.
#define END_JSON_DUMP_FILE() \
    writer.EndObject();      \
    file << s.GetString();   \
    file.close();            \
    std::cout << "Results written to " << file_str << " successfully" << std::endl;

// Convenience wrappers that forward to the helpers below using the `writer` local declared by
// START_JSON_DUMP_FILE. Both expansions already end in ';', so a trailing ';' at the call site
// is optional.
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
|
||||
|
||||
template <typename T>
|
||||
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
|
||||
const char* key,
|
||||
T value)
|
||||
{
|
||||
writer.Key(key);
|
||||
if constexpr(std::is_same<T, const char*>::value)
|
||||
{
|
||||
writer.String(value, static_cast<rapidjson::SizeType>(std::strlen(value)));
|
||||
}
|
||||
else if constexpr(std::is_same<T, std::string>::value)
|
||||
{
|
||||
writer.String(value.c_str(), static_cast<rapidjson::SizeType>(value.length()));
|
||||
}
|
||||
else if constexpr(std::is_floating_point<T>::value)
|
||||
{
|
||||
writer.Double(static_cast<double>(value));
|
||||
}
|
||||
else if constexpr(std::is_integral<T>::value)
|
||||
{
|
||||
writer.Int64(static_cast<int64_t>(value));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same<T, const char*>::value || std::is_floating_point<T>::value ||
|
||||
std::is_integral<T>::value,
|
||||
"Unsupported type for JSON serialization");
|
||||
}
|
||||
}
|
||||
|
||||
/// Appends a `"perf"` member to the JSON object open on `writer`.
///
/// The member is an array holding a single object with the keys `time`, `tflops` and
/// `gbytes`, each written as a double.
///
/// Declared `inline` rather than `static`: this function lives in a header, and `inline`
/// avoids a private copy (and an unused-function warning) in every translation unit.
inline void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
                             float time,
                             float tflops,
                             float gbytes)
{
    writer.Key("perf");

    writer.StartArray();
    writer.StartObject();

    add_key_value_pair(writer, "time", time);
    add_key_value_pair(writer, "tflops", tflops);
    add_key_value_pair(writer, "gbytes", gbytes);

    writer.EndObject();
    writer.EndArray();
}
|
||||
|
||||
// Helper traits to check for static member existence
|
||||
// Detection trait: `value` is true only when `T::M_Warp_Tile`, `T::N_Warp_Tile` and
// `T::K_Warp_Tile` are all well-formed inside decltype; otherwise the primary template wins.
template <typename T, typename = void>
struct has_warp_tile_members : std::bool_constant<false>
{
};

// Chosen through SFINAE when all three member names resolve.
template <typename T>
struct has_warp_tile_members<T,
                             std::void_t<decltype(T::M_Warp_Tile),
                                         decltype(T::N_Warp_Tile),
                                         decltype(T::K_Warp_Tile)>>
    : std::bool_constant<true>
{
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmConfig,
|
||||
template <typename>
|
||||
typename DTypeTraits>
|
||||
void dump_gemm_json_results(const std::string& json_filename,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
bool persistent,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
using TraitsADataType = DTypeTraits<ADataType>;
|
||||
using TraitsBDataType = DTypeTraits<BDataType>;
|
||||
using TraitsCDataType = DTypeTraits<CDataType>;
|
||||
ADD_KEY_VALUE("A_type", TraitsADataType::name);
|
||||
ADD_KEY_VALUE("B_type", TraitsBDataType::name);
|
||||
ADD_KEY_VALUE("C_type", TraitsCDataType::name);
|
||||
ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off");
|
||||
|
||||
if constexpr(has_warp_tile_members<GemmConfig>::value)
|
||||
{
|
||||
ADD_KEY_VALUE("warp_tile",
|
||||
std::to_string(GemmConfig::M_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::N_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::K_Warp_Tile));
|
||||
}
|
||||
ADD_KEY_VALUE("persistent", persistent ? "on" : "off");
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
|
||||
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
|
||||
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
|
||||
ADD_KEY_VALUE("batch_count", batch_count);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void dump_grouped_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int group_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "grouped_gemm")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("group_count", group_count);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("DataType", datatype);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", stride_A);
|
||||
ADD_KEY_VALUE("StrideB", stride_B);
|
||||
ADD_KEY_VALUE("StrideC", stride_C);
|
||||
ADD_KEY_VALUE("kbatch", kbatch);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", StrideA);
|
||||
ADD_KEY_VALUE("StrideB", StrideB);
|
||||
ADD_KEY_VALUE("StrideD0", StrideD0);
|
||||
ADD_KEY_VALUE("StrideD1", StrideD1);
|
||||
ADD_KEY_VALUE("StrideE", StrideE);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("grid_size", grid_size);
|
||||
ADD_KEY_VALUE("block_size", block_size);
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("prec_sm", prec_sm);
|
||||
ADD_KEY_VALUE("prec_sy", prec_sy);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename DataType, template <typename> typename DTypeTraits>
|
||||
void dump_reduce_json_results(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "reduce")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
using Traits = DTypeTraits<DataType>;
|
||||
ADD_KEY_VALUE("data_type", Traits::name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("data_type", data_type);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_prec", input_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("stride_input", stride_input);
|
||||
ADD_KEY_VALUE("stride_output", stride_output);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_data_type", input_data_type);
|
||||
ADD_KEY_VALUE("quantized_data_type", quantized_data_type);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("epsilon", epsilon);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("index_prec", index_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("num_experts", num_experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("LayoutIn", layout_in);
|
||||
ADD_KEY_VALUE("LayoutOut", layout_out);
|
||||
ADD_KEY_VALUE("Precision", prec);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("api", api_str);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
ADD_KEY_VALUE("local_tokens", local_tokens);
|
||||
}
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("intermediate_size", intermediate_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("block_m", block_m);
|
||||
ADD_KEY_VALUE("activation", activation);
|
||||
ADD_KEY_VALUE("gate_only", gate_only);
|
||||
ADD_KEY_VALUE("fused_quant", fused_quant);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& bais,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("io_layout", io_layout);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
|
||||
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale_s", scale_s);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("bias", bais);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", data_type);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("i_perm", i_perm);
|
||||
ADD_KEY_VALUE("o_perm", o_perm);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_q);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_k);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale", scale);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("use_dbias", use_dbias);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("s_randval", s_randval);
|
||||
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
|
||||
ADD_KEY_VALUE("mask", mask);
|
||||
ADD_KEY_VALUE("mask_left", mask_left);
|
||||
ADD_KEY_VALUE("mask_right", mask_right);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
Reference in New Issue
Block a user