Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-04 21:11:23 +00:00
parent 5677205f88
commit 7f65be1b3e
51 changed files with 3709 additions and 189 deletions

View File

@@ -27,7 +27,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -50,7 +51,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
TransposeC>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
@@ -72,7 +74,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
TransposeC>{};
}
else
{

View File

@@ -277,6 +277,21 @@ struct BlockwiseGemmWmmaops_pipeline_base
"wrong!");
}
// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{

View File

@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v1
{
};
@@ -53,7 +54,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -90,8 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -110,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;
using Base::A_K1;
@@ -329,7 +333,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BlockSize,
ADataType,
@@ -348,7 +353,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -366,8 +372,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -386,7 +392,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;
using Base::I1;

View File

@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};
@@ -53,7 +54,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -90,7 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -109,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;
using Base::A_K1;
@@ -128,6 +133,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
using Base::GetCThreadBuffer;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
@@ -145,8 +152,21 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
}
else
{
if(num_loop == 1)
{
return TailNumber::Odd;
}
else
{
return TailNumber::Even;
}
}
}
__device__ static constexpr auto HotLoopScheduler()
@@ -362,12 +382,15 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Global prefetch 2, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// Initialize C
c_thread_buf.Clear();
@@ -379,7 +402,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0);
// main body
// Main body, perform when at least 3 loops exist.
if constexpr(HasMainLoop)
{
index_t i = 0;
@@ -448,10 +471,62 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)
// Pre-tail, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// No RunRead or MoveSrcSliceWindow here, already finished them all!
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// Tail, always perform.
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {

View File

@@ -0,0 +1,788 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
arg.p_b1_grid + b1_batch_offset,
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.c_element_op,
arg.block_2_ctile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
typename B0layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
ck::index_t LTilePerBlock, // Gemm1KPerBlock
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
ck::index_t MRepeat, // Gemm0/1 MWmmaPerWave or Mrepeat
ck::index_t LRepeat, // Gemm0 NWmmaPerWave or Nrepeat
ck::index_t NRepeat, // Gemm1 NWmmaPerWave or Nrepeat
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferDstScalarPerVector_K1,
bool B0BlockLdsAddExtraL,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferDstScalarPerVector_L1,
bool B1BlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALayout,
B0layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmGemm_Wmma_CShuffleV3;
static constexpr auto I0 = Number<0>{};
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB0_(BatchStrideB0),
BatchStrideB1_(BatchStrideB1),
BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB0_;
index_t BatchStrideB1_;
index_t BatchStrideC_;
};
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
// DataType Family
ADataType,
B0DataType,
AccDataType, // Acc0DataType
B1DataType,
AccDataType, // Acc1DataType
CShuffleDataType,
CDataType,
// ElementwiseOp Family
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
B1GridDesc,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
KPerBlock,
AK1,
BK1,
NPerBlock,
LTilePerBlock,
L1,
MPerWmma,
LPerWmma,
NPerWmma,
MRepeat,
LRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
true,
ABlockLdsAddExtraM,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_K1,
true,
B0BlockLdsAddExtraL,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_L1,
false,
B1BlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
const B1DataType* p_b1_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_b1_grid{p_b1_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
c_g_m_o_lengths = arr3{batch_count, M, O};
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
const B1DataType* p_b1_grid;
CDataType* p_c_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
arr3 c_g_m_o_lengths;
arr3 c_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf() style formmating.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B layout must be Column\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{
return false;
}
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto c_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
}
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b0,
const void* p_b1,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA,
ck::index_t StrideB0,
ck::index_t StrideB1,
ck::index_t StrideC,
ck::index_t BatchStrideA,
ck::index_t BatchStrideB0,
ck::index_t BatchStrideB1,
ck::index_t BatchStrideC,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
template <typename T>
static constexpr const char* DataTypeToString()
{
if constexpr(std::is_same_v<T, float>)
{
return "fp32";
}
else if constexpr(std::is_same_v<T, ck::half_t>)
{
return "fp16";
}
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
{
return "bf16";
}
else if constexpr(std::is_same_v<T, ck::f8_t>)
{
return "fp8";
}
else if constexpr(std::is_same_v<T, ck::bf8_t>)
{
return "bf8";
}
else if constexpr(std::is_same_v<T, int32_t>)
{
return "int32";
}
else if constexpr(std::is_same_v<T, int8_t>)
{
return "int8";
}
else if constexpr(std::is_same_v<T, ck::int4_t>)
{
return "int4";
}
else
{
return "unknown";
}
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3"
<< "<"
<< ALayout::name[0]
<< B0layout::name[0]
<< B1Layout::name[0]
<< CLayout::name[0] << ", "
<< "A " << DataTypeToString<ADataType>() << ", "
<< "B0 " << DataTypeToString<B0DataType>() << ", "
<< "B1 " << DataTypeToString<B1DataType>() << ", "
<< "C " << DataTypeToString<CDataType>() << ", "
<< "Acc " << DataTypeToString<AccDataType>() << ", "
<< "Cshuf " << DataTypeToString<CShuffleDataType>() << ", "
<< BlockSize << ", "
<< MPerBlock << ", "
<< LPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< LTilePerBlock << ", "
<< L1 << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -243,6 +243,30 @@ inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
return u.fp16;
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, f8_t>(f8_t x)
{
union
{
f8_t fp8;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr f8_t type_convert_sp<f8_t, int>(int x)
{
union
{
int int32;
f8_t fp8;
} u = {x};
return u.fp8;
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, bhalf_t>(bhalf_t x)
{

View File

@@ -0,0 +1,700 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
#include "rapidjson/writer.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/document.h"
#include "rapidjson/rapidjson.h"
// #include <fstream>
#pragma GCC diagnostic pop
#define START_JSON_DUMP_FILE(file_name) \
std::string file_str(file_name); \
std::ofstream file(file_str); \
if(!file.is_open()) \
{ \
throw std::runtime_error("Could not open file: " + std::string(file_name)); \
} \
rapidjson::StringBuffer s; \
rapidjson::Writer<rapidjson::StringBuffer> writer(s); \
writer.StartObject();
#define END_JSON_DUMP_FILE() \
writer.EndObject(); \
file << s.GetString(); \
file.close(); \
std::cout << "Results written to " << file_str << " successfully" << std::endl;
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
template <typename T>
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
const char* key,
T value)
{
writer.Key(key);
if constexpr(std::is_same<T, const char*>::value)
{
writer.String(value, static_cast<rapidjson::SizeType>(std::strlen(value)));
}
else if constexpr(std::is_same<T, std::string>::value)
{
writer.String(value.c_str(), static_cast<rapidjson::SizeType>(value.length()));
}
else if constexpr(std::is_floating_point<T>::value)
{
writer.Double(static_cast<double>(value));
}
else if constexpr(std::is_integral<T>::value)
{
writer.Int64(static_cast<int64_t>(value));
}
else
{
static_assert(std::is_same<T, const char*>::value || std::is_floating_point<T>::value ||
std::is_integral<T>::value,
"Unsupported type for JSON serialization");
}
}
static void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
float time,
float tflops,
float gbytes)
{
std::string roster("perf");
writer.String(roster.c_str(), static_cast<rapidjson::SizeType>(roster.length()));
writer.StartArray();
writer.StartObject();
add_key_value_pair(writer, "time", time);
add_key_value_pair(writer, "tflops", tflops);
add_key_value_pair(writer, "gbytes", gbytes);
writer.EndObject();
writer.EndArray();
}
// Helper traits to check for static member existence
template <typename T, typename = void>
struct has_warp_tile_members : std::false_type
{
};
template <typename T>
struct has_warp_tile_members<
T,
std::void_t<decltype(T::M_Warp_Tile), decltype(T::N_Warp_Tile), decltype(T::K_Warp_Tile)>>
: std::true_type
{
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmConfig,
template <typename>
typename DTypeTraits>
void dump_gemm_json_results(const std::string& json_filename,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
bool persistent,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("stride_A", stride_A);
ADD_KEY_VALUE("stride_B", stride_B);
ADD_KEY_VALUE("stride_C", stride_C);
ADD_KEY_VALUE("A_layout", ALayout::name);
ADD_KEY_VALUE("B_layout", BLayout::name);
ADD_KEY_VALUE("C_layout", CLayout::name);
using TraitsADataType = DTypeTraits<ADataType>;
using TraitsBDataType = DTypeTraits<BDataType>;
using TraitsCDataType = DTypeTraits<CDataType>;
ADD_KEY_VALUE("A_type", TraitsADataType::name);
ADD_KEY_VALUE("B_type", TraitsBDataType::name);
ADD_KEY_VALUE("C_type", TraitsCDataType::name);
ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off");
if constexpr(has_warp_tile_members<GemmConfig>::value)
{
ADD_KEY_VALUE("warp_tile",
std::to_string(GemmConfig::M_Warp_Tile) + "x" +
std::to_string(GemmConfig::N_Warp_Tile) + "x" +
std::to_string(GemmConfig::K_Warp_Tile));
}
ADD_KEY_VALUE("persistent", persistent ? "on" : "off");
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("stride_A", stride_A);
ADD_KEY_VALUE("stride_B", stride_B);
ADD_KEY_VALUE("stride_C", stride_C);
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
ADD_KEY_VALUE("batch_count", batch_count);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename ALayout, typename BLayout, typename CLayout>
void dump_grouped_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int group_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "grouped_gemm")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("group_count", group_count);
ADD_KEY_VALUE("A_layout", ALayout::name);
ADD_KEY_VALUE("B_layout", BLayout::name);
ADD_KEY_VALUE("C_layout", CLayout::name);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("DataType", datatype);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", stride_A);
ADD_KEY_VALUE("StrideB", stride_B);
ADD_KEY_VALUE("StrideC", stride_C);
ADD_KEY_VALUE("kbatch", kbatch);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", StrideA);
ADD_KEY_VALUE("StrideB", StrideB);
ADD_KEY_VALUE("StrideD0", StrideD0);
ADD_KEY_VALUE("StrideD1", StrideD1);
ADD_KEY_VALUE("StrideE", StrideE);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("grid_size", grid_size);
ADD_KEY_VALUE("block_size", block_size);
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("prec_sm", prec_sm);
ADD_KEY_VALUE("prec_sy", prec_sy);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("xr_stride", xr_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("yr_stride", yr_stride);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename DataType, template <typename> typename DTypeTraits>
void dump_reduce_json_results(const std::string& json_filename,
int N,
int C,
int H,
int W,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "reduce")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
using Traits = DTypeTraits<DataType>;
ADD_KEY_VALUE("data_type", Traits::name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("data_type", data_type);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("input_prec", input_prec);
ADD_KEY_VALUE("weight_prec", weight_prec);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("stride_input", stride_input);
ADD_KEY_VALUE("stride_output", stride_output);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("xr_stride", xr_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("yr_stride", yr_stride);
ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_add_rmsnorm2d_rdquant_fwd_json(
const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("input_data_type", input_data_type);
ADD_KEY_VALUE("quantized_data_type", quantized_data_type);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("epsilon", epsilon);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("index_prec", index_prec);
ADD_KEY_VALUE("weight_prec", weight_prec);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("num_experts", num_experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("LayoutIn", layout_in);
ADD_KEY_VALUE("LayoutOut", layout_out);
ADD_KEY_VALUE("Precision", prec);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("api", api_str);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("tokens", tokens);
if(is_local_token)
{
ADD_KEY_VALUE("local_tokens", local_tokens);
}
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("intermediate_size", intermediate_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("block_m", block_m);
ADD_KEY_VALUE("activation", activation);
ADD_KEY_VALUE("gate_only", gate_only);
ADD_KEY_VALUE("fused_quant", fused_quant);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
END_JSON_DUMP_FILE();
}
void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
bool squant,
const std::string& bais,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("io_layout", io_layout);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale_s", scale_s);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("lse", lse);
ADD_KEY_VALUE("squant", squant);
ADD_KEY_VALUE("bias", bais);
ADD_KEY_VALUE("vlayout", vlayout);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", data_type);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("i_perm", i_perm);
ADD_KEY_VALUE("o_perm", o_perm);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_q);
ADD_KEY_VALUE("seqlen_k", seqlen_k);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale", scale);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("use_dbias", use_dbias);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("s_randval", s_randval);
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
ADD_KEY_VALUE("mask", mask);
ADD_KEY_VALUE("mask_left", mask_left);
ADD_KEY_VALUE("mask_right", mask_right);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}