mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop
This commit is contained in:
@@ -27,7 +27,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
@@ -50,7 +51,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
TransposeC>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -72,7 +74,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
TransposeC>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -277,6 +277,21 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
|
||||
@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -53,7 +54,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -90,8 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -110,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -329,7 +333,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -348,7 +353,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -366,8 +372,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -386,7 +392,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3
|
||||
{
|
||||
};
|
||||
@@ -53,7 +54,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -90,7 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -109,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -128,6 +133,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
|
||||
using Base::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
|
||||
|
||||
using Base::a_block_desc_k0_m0_m1_m2_k1;
|
||||
using Base::b_block_desc_k0_n0_n1_n2_k1;
|
||||
@@ -145,8 +152,21 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
@@ -362,12 +382,15 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
// Global prefetch 2, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
@@ -379,7 +402,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
// Main body, perform when at least 3 loops exist.
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
@@ -448,10 +471,62 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
|
||||
// Pre-tail, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// No RunRead or MoveSrcSliceWindow here, already finished them all!
|
||||
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Tail, always perform.
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
|
||||
@@ -0,0 +1,788 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.c_element_op,
|
||||
arg.block_2_ctile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// MN = MK * KL * LN
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <typename ALayout,
|
||||
typename B0layout,
|
||||
typename B1Layout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
ck::index_t LTilePerBlock, // Gemm1KPerBlock
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
ck::index_t MRepeat, // Gemm0/1 MWmmaPerWave or Mrepeat
|
||||
ck::index_t LRepeat, // Gemm0 NWmmaPerWave or Nrepeat
|
||||
ck::index_t NRepeat, // Gemm1 NWmmaPerWave or Nrepeat
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
|
||||
typename B0BlockTransferThreadClusterArrangeOrder,
|
||||
typename B0BlockTransferSrcAccessOrder,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferDstScalarPerVector_K1,
|
||||
bool B0BlockLdsAddExtraL,
|
||||
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferDstScalarPerVector_L1,
|
||||
bool B1BlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
|
||||
struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALayout,
|
||||
B0layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmGemm_Wmma_CShuffleV3;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
index_t BatchStrideB1_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
B0DataType,
|
||||
AccDataType, // Acc0DataType
|
||||
B1DataType,
|
||||
AccDataType, // Acc1DataType
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
// ElementwiseOp Family
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
B1GridDesc,
|
||||
CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
NPerBlock,
|
||||
LTilePerBlock,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
LRepeat,
|
||||
NRepeat,
|
||||
// ThreadCluster Family
|
||||
BlockSize,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
ABlockLdsAddExtraM,
|
||||
B0BlockTransferThreadClusterLengths_K0_L_K1,
|
||||
B0BlockTransferThreadClusterArrangeOrder,
|
||||
B0BlockTransferSrcAccessOrder,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
B0BlockLdsAddExtraL,
|
||||
B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_L1,
|
||||
false,
|
||||
B1BlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
c_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
CDataType* p_c_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
arr3 c_g_m_o_lengths;
|
||||
arr3 c_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB0,
|
||||
ck::index_t StrideB1,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB0,
|
||||
ck::index_t BatchStrideB1,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
AccElementwiseOperation acc_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
return "fp32";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
{
|
||||
return "fp16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::f8_t>)
|
||||
{
|
||||
return "fp8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int32_t>)
|
||||
{
|
||||
return "int32";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
{
|
||||
return "int8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::int4_t>)
|
||||
{
|
||||
return "int4";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< ALayout::name[0]
|
||||
<< B0layout::name[0]
|
||||
<< B1Layout::name[0]
|
||||
<< CLayout::name[0] << ", "
|
||||
<< "A " << DataTypeToString<ADataType>() << ", "
|
||||
<< "B0 " << DataTypeToString<B0DataType>() << ", "
|
||||
<< "B1 " << DataTypeToString<B1DataType>() << ", "
|
||||
<< "C " << DataTypeToString<CDataType>() << ", "
|
||||
<< "Acc " << DataTypeToString<AccDataType>() << ", "
|
||||
<< "Cshuf " << DataTypeToString<CShuffleDataType>() << ", "
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< LPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< LTilePerBlock << ", "
|
||||
<< L1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">"
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -243,6 +243,30 @@ inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
|
||||
return u.fp16;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int type_convert_sp<int, f8_t>(f8_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
f8_t fp8;
|
||||
int int32;
|
||||
} u = {x};
|
||||
|
||||
return u.int32;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr f8_t type_convert_sp<f8_t, int>(int x)
|
||||
{
|
||||
union
|
||||
{
|
||||
int int32;
|
||||
f8_t fp8;
|
||||
} u = {x};
|
||||
|
||||
return u.fp8;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int type_convert_sp<int, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
|
||||
700
include/ck_tile/utility/json_dump.hpp
Normal file
700
include/ck_tile/utility/json_dump.hpp
Normal file
@@ -0,0 +1,700 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
|
||||
#include "rapidjson/writer.h"
|
||||
#include "rapidjson/stringbuffer.h"
|
||||
#include "rapidjson/document.h"
|
||||
#include "rapidjson/rapidjson.h"
|
||||
// #include <fstream>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#define START_JSON_DUMP_FILE(file_name) \
|
||||
std::string file_str(file_name); \
|
||||
std::ofstream file(file_str); \
|
||||
if(!file.is_open()) \
|
||||
{ \
|
||||
throw std::runtime_error("Could not open file: " + std::string(file_name)); \
|
||||
} \
|
||||
rapidjson::StringBuffer s; \
|
||||
rapidjson::Writer<rapidjson::StringBuffer> writer(s); \
|
||||
writer.StartObject();
|
||||
|
||||
#define END_JSON_DUMP_FILE() \
|
||||
writer.EndObject(); \
|
||||
file << s.GetString(); \
|
||||
file.close(); \
|
||||
std::cout << "Results written to " << file_str << " successfully" << std::endl;
|
||||
|
||||
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
|
||||
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
|
||||
|
||||
template <typename T>
|
||||
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
|
||||
const char* key,
|
||||
T value)
|
||||
{
|
||||
writer.Key(key);
|
||||
if constexpr(std::is_same<T, const char*>::value)
|
||||
{
|
||||
writer.String(value, static_cast<rapidjson::SizeType>(std::strlen(value)));
|
||||
}
|
||||
else if constexpr(std::is_same<T, std::string>::value)
|
||||
{
|
||||
writer.String(value.c_str(), static_cast<rapidjson::SizeType>(value.length()));
|
||||
}
|
||||
else if constexpr(std::is_floating_point<T>::value)
|
||||
{
|
||||
writer.Double(static_cast<double>(value));
|
||||
}
|
||||
else if constexpr(std::is_integral<T>::value)
|
||||
{
|
||||
writer.Int64(static_cast<int64_t>(value));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same<T, const char*>::value || std::is_floating_point<T>::value ||
|
||||
std::is_integral<T>::value,
|
||||
"Unsupported type for JSON serialization");
|
||||
}
|
||||
}
|
||||
|
||||
static void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
|
||||
float time,
|
||||
float tflops,
|
||||
float gbytes)
|
||||
{
|
||||
std::string roster("perf");
|
||||
writer.String(roster.c_str(), static_cast<rapidjson::SizeType>(roster.length()));
|
||||
|
||||
writer.StartArray();
|
||||
writer.StartObject();
|
||||
|
||||
add_key_value_pair(writer, "time", time);
|
||||
add_key_value_pair(writer, "tflops", tflops);
|
||||
add_key_value_pair(writer, "gbytes", gbytes);
|
||||
|
||||
writer.EndObject();
|
||||
writer.EndArray();
|
||||
}
|
||||
|
||||
// Helper traits to check for static member existence
|
||||
template <typename T, typename = void>
|
||||
struct has_warp_tile_members : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_warp_tile_members<
|
||||
T,
|
||||
std::void_t<decltype(T::M_Warp_Tile), decltype(T::N_Warp_Tile), decltype(T::K_Warp_Tile)>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmConfig,
|
||||
template <typename>
|
||||
typename DTypeTraits>
|
||||
void dump_gemm_json_results(const std::string& json_filename,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
bool persistent,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
using TraitsADataType = DTypeTraits<ADataType>;
|
||||
using TraitsBDataType = DTypeTraits<BDataType>;
|
||||
using TraitsCDataType = DTypeTraits<CDataType>;
|
||||
ADD_KEY_VALUE("A_type", TraitsADataType::name);
|
||||
ADD_KEY_VALUE("B_type", TraitsBDataType::name);
|
||||
ADD_KEY_VALUE("C_type", TraitsCDataType::name);
|
||||
ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off");
|
||||
|
||||
if constexpr(has_warp_tile_members<GemmConfig>::value)
|
||||
{
|
||||
ADD_KEY_VALUE("warp_tile",
|
||||
std::to_string(GemmConfig::M_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::N_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::K_Warp_Tile));
|
||||
}
|
||||
ADD_KEY_VALUE("persistent", persistent ? "on" : "off");
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
|
||||
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
|
||||
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
|
||||
ADD_KEY_VALUE("batch_count", batch_count);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void dump_grouped_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int group_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "grouped_gemm")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("group_count", group_count);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("DataType", datatype);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", stride_A);
|
||||
ADD_KEY_VALUE("StrideB", stride_B);
|
||||
ADD_KEY_VALUE("StrideC", stride_C);
|
||||
ADD_KEY_VALUE("kbatch", kbatch);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", StrideA);
|
||||
ADD_KEY_VALUE("StrideB", StrideB);
|
||||
ADD_KEY_VALUE("StrideD0", StrideD0);
|
||||
ADD_KEY_VALUE("StrideD1", StrideD1);
|
||||
ADD_KEY_VALUE("StrideE", StrideE);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("grid_size", grid_size);
|
||||
ADD_KEY_VALUE("block_size", block_size);
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("prec_sm", prec_sm);
|
||||
ADD_KEY_VALUE("prec_sy", prec_sy);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename DataType, template <typename> typename DTypeTraits>
|
||||
void dump_reduce_json_results(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "reduce")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
using Traits = DTypeTraits<DataType>;
|
||||
ADD_KEY_VALUE("data_type", Traits::name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("data_type", data_type);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_prec", input_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("stride_input", stride_input);
|
||||
ADD_KEY_VALUE("stride_output", stride_output);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_data_type", input_data_type);
|
||||
ADD_KEY_VALUE("quantized_data_type", quantized_data_type);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("epsilon", epsilon);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("index_prec", index_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("num_experts", num_experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("LayoutIn", layout_in);
|
||||
ADD_KEY_VALUE("LayoutOut", layout_out);
|
||||
ADD_KEY_VALUE("Precision", prec);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("api", api_str);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
ADD_KEY_VALUE("local_tokens", local_tokens);
|
||||
}
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("intermediate_size", intermediate_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("block_m", block_m);
|
||||
ADD_KEY_VALUE("activation", activation);
|
||||
ADD_KEY_VALUE("gate_only", gate_only);
|
||||
ADD_KEY_VALUE("fused_quant", fused_quant);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& bais,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("io_layout", io_layout);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
|
||||
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale_s", scale_s);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("bias", bais);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", data_type);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("i_perm", i_perm);
|
||||
ADD_KEY_VALUE("o_perm", o_perm);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_q);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_k);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale", scale);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("use_dbias", use_dbias);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("s_randval", s_randval);
|
||||
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
|
||||
ADD_KEY_VALUE("mask", mask);
|
||||
ADD_KEY_VALUE("mask_left", mask_left);
|
||||
ADD_KEY_VALUE("mask_right", mask_right);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
Reference in New Issue
Block a user