Remove code duplications in batched gemm (multi D) gemm (multi D) wmma (#3617)

* Added a common struct to enable code reduction in the gemm gemm and gemm multi_d gemm multi_d wmma implementations

This file includes all components shared between the two implementations: the kernel, the pointer-offset computation struct, the grid descriptor creator and its associated type definitions, the invoker struct, and the argument struct.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Used the common struct in the batched gemm gemm wmma cshuffle v3 implementation

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Used the shared structs in the gemm multiple D gemm multiple D wmma cshuffle v3 implementation

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Boy-scout: IWYU paradigm in the gemm gemm and gemm multiple D gemm multiple D wmma cshuffle v3 implementations

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

---------

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
chris-tsiaousis-hpc
2026-01-26 19:20:30 +01:00
committed by GitHub
parent de59c0716c
commit 917f35553a
3 changed files with 1173 additions and 1163 deletions

View File

@@ -3,77 +3,21 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
Tuple<>{}, // p_d0s_grid
arg.p_b1_grid + b1_batch_offset,
Tuple<>{}, // p_d1s_grid
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.b1_grid_desc,
Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.c_element_op,
arg.block_2_ctile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
@@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB0_(BatchStrideB0),
BatchStrideB1_(BatchStrideB1),
BatchStrideC_(BatchStrideC)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB0_;
index_t BatchStrideB1_;
index_t BatchStrideC_;
};
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType
Tuple<>, // D1sDataType
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
Tuple<>, // Ds0GridDesc
B1GridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
Tuple<>, // Ds1GridDesc
CGridDesc_M_N,
typename DeviceGemmGemmCommonBase::CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
Tuple<>, // D0sLayout
B1Layout,
Tuple<>, // D1sLayout
CLayout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CDataType,
Tuple<>, // D0sDataType,
Tuple<>, // D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector
CShuffleBlockTransferScalarPerVector_NPerBlock,
false>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
const B1DataType* p_b1_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_b1_grid{p_b1_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
c_g_m_o_lengths = arr3{batch_count, M, O};
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
const B1DataType* p_b1_grid;
CDataType* p_c_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
arr3 c_g_m_o_lengths;
arr3 c_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf() style formmating.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B layout must be Column\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
Tuple<>{},
arg.b1_grid_desc,
Tuple<>{},
arg.c_grid_desc_m_n,
arg.block_2_ctile_map))
{
return false;
}
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto c_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
return DeviceGemmGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b0,
@@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideB1,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
std::array<const void*, DeviceGemmGemmCommonBase::NumD0Tensor> p_d0_grid{};
std::array<const void*, DeviceGemmGemmCommonBase::NumD1Tensor> p_d1_grid{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD0Tensor> StrideD0s{}, BatchStrideD0s{};
std::array<index_t, DeviceGemmGemmCommonBase::NumD1Tensor> StrideD1s, BatchStrideD1s{};
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0_grid,
static_cast<const B1DataType*>(p_b1),
p_d1_grid,
static_cast<CDataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideC,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideC,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }

View File

@@ -0,0 +1,902 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <iostream>
#include <cstdarg>
#include <type_traits>
#include <utility>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Batched GEMM * GEMM kernel shared by the plain gemm-gemm and the
// multi-D (gemm-multiD-gemm-multiD) WMMA CShuffleV3 device operations.
// Each block derives its batch index from its 1D block id (the invoker
// launches grid_size = batch_count * tile_count, so the division below is
// exact — see the device-op Invoker), offsets every tensor pointer by the
// per-batch base offset, then delegates the tile computation to GridwiseOp.
// IsMultiD selects whether the auxiliary D0/D1 tensor tuples are populated
// or passed as empty Tuple<>{} placeholders.
template <typename DeviceOp,
typename GridwiseOp,
bool HasMainKBlockLoop,
TailNumber TailNum,
bool IsMultiD>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
// LDS scratch buffer; size is dictated by the gridwise op.
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
// g_idx = batch this block processes; readfirstlane keeps the value in a
// scalar register since it is uniform across the wavefront.
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// Per-batch element offsets for each tensor, computed from batch strides.
const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_e1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx)));
// Build the per-batch D0/D1 grid pointer tuples. In the multi-D case each
// D pointer is shifted by its own batch offset; otherwise empty tuples are
// returned so both variants can share the single Run(...) call below.
auto [p_d0s_grid, p_d1s_grid] = [&]() {
if constexpr(IsMultiD)
{
// Offsets every pointer in grid_pointer by the batch offset that
// func(In) reports for tensor index In.
auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) {
static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) {
const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In));
grid_pointer(In) = arg_grid(In) + batch_offset;
});
return std::move(grid_pointer);
};
auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) {
return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx);
};
auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) {
return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx);
};
auto d0s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD0Tensor>{},
get_d0_base_ptr,
arg.p_d0s_grid,
GridwiseOp::MakeD0sGridPointer());
auto d1s_grid = create_grid(ck::integral_constant<ck::index_t, DeviceOp::NumD1Tensor>{},
get_d1_base_ptr,
arg.p_d1s_grid,
GridwiseOp::MakeD1sGridPointer());
return std::make_pair(d0s_grid, d1s_grid);
}
else
{
return std::make_pair(Tuple<>{}, Tuple<>{});
}
}();
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
p_d0s_grid,
arg.p_b1_grid + b1_batch_offset,
p_d1s_grid,
arg.p_c_e1_grid + c_e1_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.cde1_element_op,
arg.block_2_etile_map);
#else
// Body compiled out on architectures without WMMA (non-gfx11/gfx12).
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
}
// Shared building blocks (grid descriptor factory, descriptor type aliases,
// per-batch base-pointer calculator) for the batched gemm-gemm and the
// gemm-multiD-gemm-multiD WMMA CShuffleV3 device operations.
// When IsMultiD is false the D-tensor machinery degenerates to
// zero-length arrays / empty tuples, so one definition serves both ops.
template <typename DeviceOp,
GemmSpecialization GemmSpec,
typename ALayout,
typename B0layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename CE1Layout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
typename ADataType,
typename B0DataType,
typename B1DataType,
typename AccDataType,
typename CE1DataType,
typename D0sDataType,
typename D1sDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
BlockGemmPipelineVersion BlkGemmPipelineVer,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t CDE0BlockTransferSrcScalarPerVector,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common
{
// Number of auxiliary D0 tensors (first-gemm epilogue inputs); taken from
// the device op only in the multi-D case, otherwise 0.
static constexpr ck::index_t NumD0Tensor = []() {
if constexpr(IsMultiD)
{
return DeviceOp::NumD0Tensor;
}
return 0;
}();
// Number of auxiliary D1 tensors (second-gemm epilogue inputs); 0 when
// IsMultiD is false.
static constexpr ck::index_t NumD1Tensor = []() {
if constexpr(IsMultiD)
{
return DeviceOp::NumD1Tensor;
}
return 0;
}();
// Factory for grid descriptors. All inputs are {G, dim0, dim1}
// length/stride triples, routed through the batched contraction->gemm
// transform.
struct GridDescriptorCreator
{
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
// A [G, M, K] -> (AK0, M, AK1) descriptor; AK1 is the K-split factor.
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
// B0 [G, L, K] -> (BK0, N, BK1) descriptor; BK1 is the K-split factor.
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
// B1 [G, N, L] -> (BK0, N, BK1) descriptor; L1 (B1K1) is the split factor.
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
// Single D0 tensor [G, M, N] -> M_N descriptor (same transform as C/E).
__host__ __device__ static auto
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
}
// Tuple of NumD0Tensor D0 descriptors, one per auxiliary tensor.
__host__ __device__ static auto MakeD0sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
},
Number<NumD0Tensor>{});
}
// Tuple of NumD1Tensor D1 descriptors (D1 tensors share the E1 layout).
__host__ __device__ static auto MakeD1sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD1Tensor>& d1_g_m_o_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
},
Number<NumD1Tensor>{});
}
// Output E1 (a.k.a. C) tensor [G, M, N] -> M_N descriptor.
__host__ __device__ static auto
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
}
};
// Descriptor types, deduced from the factory with default-sized inputs.
using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {}));
using D0sGridDesc =
remove_cvref_t<decltype(GridDescriptorCreator::MakeD0sGridDescriptor({}, {}))>;
using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {}));
using D1sGridDesc =
remove_cvref_t<decltype(GridDescriptorCreator::MakeD1sGridDescriptor({}, {}))>;
using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {}));
using CGridDesc_M_N =
decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {}));
// Translates a batch index g_idx into element base offsets (in elements,
// widened to long_index_t) for each tensor from the per-tensor batch
// strides captured at construction.
struct ComputeBasePtrOfStridedBatch
{
// Non-multi-D constructor. NOTE(review): BatchStrideD0s_/BatchStrideD1s_
// are not in the init list — safe only while NumD0Tensor == NumD1Tensor
// == 0 (zero-length arrays); confirm if this ctor is ever used otherwise.
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB0_(BatchStrideB0),
BatchStrideB1_(BatchStrideB1),
BatchStrideC_E1_(BatchStrideC)
{
}
// Multi-D constructor: additionally records D0/D1 batch strides.
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1)
: BatchStrideA_(BatchStrideA0),
BatchStrideB0_(BatchStrideB0),
BatchStrideD0s_(BatchStrideD0s),
BatchStrideB1_(BatchStrideB1),
BatchStrideD1s_(BatchStrideD1s),
BatchStrideC_E1_(BatchStrideE1)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_E1_);
}
// Base offset of the d0_idx-th D0 tensor for batch g_idx.
template <index_t I>
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
Number<I> d0_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
}
// Base offset of the d1_idx-th D1 tensor for batch g_idx.
template <index_t I>
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
Number<I> d1_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
}
private:
index_t BatchStrideA_;
index_t BatchStrideB0_;
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
index_t BatchStrideB1_;
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
index_t BatchStrideC_E1_;
};
};
template <typename DeviceOp,
GemmSpecialization GemmSpec,
typename ALayout,
typename B0layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename CE1Layout,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
typename ADataType,
typename B0DataType,
typename B1DataType,
typename AccDataType,
typename CE1DataType,
typename D0sDataType,
typename D1sDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CDE1ElementwiseOperation,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
BlockGemmPipelineVersion BlkGemmPipelineVer,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t CDE0BlockTransferSrcScalarPerVector,
ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool IsMultiD = false>
struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg
{
using GridwiseGemm = typename DeviceOp::GridwiseOp;
using Common =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
CE1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
CE1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
IsMultiD>;
static constexpr auto NumD0Tensor = Common::NumD0Tensor;
static constexpr auto NumD1Tensor = Common::NumD1Tensor;
/// @brief Host-side argument bundle shared by the batched gemm+gemm and the
///        multi-D gemm+gemm WMMA device operations.
///
/// Captures raw device pointers, raw problem sizes and elementwise ops, and
/// eagerly builds every grid descriptor plus the block-to-E-tile map in the
/// constructor, so the kernel receives a ready-to-run description by value.
struct Argument : public BaseArgument
{
    // [batch, outer, inner] length/stride triple describing one 3-D tensor view.
    using arr3 = std::array<ck::index_t, 3>;
    // M_/N_/K_: Gemm0 problem size (M x N over K); O_: Gemm1 free dimension
    // (output E1 is M x O). Stride* are leading-dimension strides, BatchStride*
    // advance one batch.
    Argument(const ADataType* p_a_grid_,
             const B0DataType* p_b0_grid_,
             std::array<const void*, NumD0Tensor> p_d0s_grid_,
             const B1DataType* p_b1_grid_,
             std::array<const void*, NumD1Tensor> p_d1s_grid_,
             CE1DataType* p_e1_grid_,
             index_t M_,
             index_t N_,
             index_t K_,
             index_t O_,
             index_t Batch,
             index_t StrideA,
             index_t StrideB0,
             std::array<index_t, NumD0Tensor> StrideD0s,
             index_t StrideB1,
             std::array<index_t, NumD1Tensor> StrideD1s,
             index_t StrideE1,
             index_t BatchStrideA,
             index_t BatchStrideB0,
             std::array<index_t, NumD0Tensor> BatchStrideD0s,
             index_t BatchStrideB1,
             std::array<index_t, NumD1Tensor> BatchStrideD1s,
             index_t BatchStrideE1,
             AElementwiseOperation a_element_op_,
             B0ElementwiseOperation b0_element_op_,
             AccElementwiseOperation acc_element_op_,
             B1ElementwiseOperation b1_element_op_,
             CDE1ElementwiseOperation cde1_element_op_)
        : p_a_grid{p_a_grid_},
          p_b0_grid{p_b0_grid_},
          p_d0s_grid{},
          p_b1_grid{p_b1_grid_},
          p_d1s_grid{},
          p_c_e1_grid{p_e1_grid_},
          M{M_},
          N{N_},
          K{K_},
          O{O_},
          batch_count{Batch},
          a_element_op{a_element_op_},
          b0_element_op{b0_element_op_},
          acc_element_op{acc_element_op_},
          b1_element_op{b1_element_op_},
          cde1_element_op{cde1_element_op_},
          compute_base_ptr_of_batch{BatchStrideA,
                                    BatchStrideB0,
                                    BatchStrideD0s,
                                    BatchStrideB1,
                                    BatchStrideD1s,
                                    BatchStrideE1}
    {
        // Raw tensor views. The innermost stride of A, B0 and C/E1 is
        // hardcoded to 1 (packed innermost dimension).
        a_g_m_k_lengths = arr3{batch_count, M, K};
        a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
        b0_g_n_k_lengths = arr3{batch_count, N, K};
        b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
        // B1's logical view is always (batch, O, N); only the strides change
        // with the storage layout.
        b1_g_o_n_lengths = arr3{batch_count, O, N};
        b1_g_o_n_strides =
            is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
                ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
                : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
        e1_g_m_o_lengths = arr3{batch_count, M, O};
        e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
        // Turn the raw views into the blocked grid descriptors the kernel consumes.
        a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths,
                                                                         a_g_m_k_strides);
        b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths,
                                                                           b0_g_n_k_strides);
        b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths,
                                                                           b1_g_o_n_strides);
        c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor(
            e1_g_m_o_lengths, e1_g_m_o_strides);
        c_e1_grid_desc_mblock_mperblock_nblock_nperblock =
            GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_e1_grid_desc_m_n);
        block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1);
        // The D0/D1 auxiliary tensors only exist in the multi-D variant; in the
        // plain variant the conditional members below collapse to empty tuples.
        if constexpr(IsMultiD)
        {
            static_for<0, NumD0Tensor, 1>{}([&](auto i) {
                using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
                // D0s layout [batch_count, M, N]
                d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
                d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
                // D0 pointer
                p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
            });
            // D0 desc
            d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor(
                d0s_g_m_n_lengths, d0s_g_m_n_strides);
            static_for<0, NumD1Tensor, 1>{}([&](auto i) {
                using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
                // D1s layout [batch_count, M, O]
                d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
                d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
                // D1 pointer
                p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
            });
            // D1 desc
            d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor(
                d1s_g_m_o_lengths, d1s_g_m_o_strides);
            d1s_grid_desc_mblock_mperblock_nblock_nperblock =
                GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    d1s_grid_desc);
        }
    }
    // Pointers (device memory; D0s/D1s are tuple-of-pointer wrappers).
    const ADataType* p_a_grid;
    const B0DataType* p_b0_grid;
    typename GridwiseGemm::D0sGridPointer p_d0s_grid;
    const B1DataType* p_b1_grid;
    typename GridwiseGemm::D1sGridPointer p_d1s_grid;
    CE1DataType* p_c_e1_grid;
    // Raw Problem Size
    index_t M;
    index_t N;
    index_t K;
    index_t O;
    index_t batch_count;
    // Raw [batch, outer, inner] views kept for vector-width validation in
    // IsSupportedArgument.
    arr3 a_g_m_k_lengths;
    arr3 a_g_m_k_strides;
    arr3 b0_g_n_k_lengths;
    arr3 b0_g_n_k_strides;
    std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
    std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
    arr3 b1_g_o_n_lengths;
    arr3 b1_g_o_n_strides;
    std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
    std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
    arr3 e1_g_m_o_lengths;
    arr3 e1_g_m_o_strides;
    // Elementwise functors applied by the gridwise op.
    AElementwiseOperation a_element_op;
    B0ElementwiseOperation b0_element_op;
    AccElementwiseOperation acc_element_op;
    B1ElementwiseOperation b1_element_op;
    CDE1ElementwiseOperation cde1_element_op;
    // Grid descriptors and other mem calculators. The D0s/D1s-related members
    // degrade to empty Tuple<> when IsMultiD is false.
    typename Common::AGridDesc a_grid_desc;
    typename Common::B0GridDesc b0_grid_desc;
    std::conditional_t<IsMultiD, typename Common::D0sGridDesc, Tuple<>> d0s_grid_desc;
    typename Common::B1GridDesc b1_grid_desc;
    typename Common::D1sGridDesc d1s_grid_desc;
    std::conditional_t<
        IsMultiD,
        typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
        Tuple<>>
        d1s_grid_desc_mblock_mperblock_nblock_nperblock;
    std::conditional_t<IsMultiD, typename Common::E1GridDesc, typename Common::CGridDesc_M_N>
        c_e1_grid_desc_m_n;
    typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        c_e1_grid_desc_mblock_mperblock_nblock_nperblock;
    typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map;
    typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
/// kernel function. It usually determines the launched grid size prepares kernel
/// arguments as well as perform specific kernel configuration selection based on
/// runtime arguments.
///
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tail_num = decltype(tail_number)::value;
const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp,
GridwiseGemm,
has_loop,
tail_num,
IsMultiD>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time sanity check of the template configuration.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
// check if DsLayout is supported
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
bool valid = true;
// iterate over DLayout tuple
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// if RefLayout and DLayout are same, keep valid true, otherwise false
valid = valid && is_same_v<RefLayout, DLayout>;
});
return valid;
}
/// @brief Runtime/compile-time check that this device op can handle `arg`:
///        GPU architecture, data types, layouts, padding mode, Wmma tile
///        shape, descriptor validity and vector load/store divisibility.
static bool IsSupportedArgument(const Argument& arg)
{
    // Print lambda with env check and printf() style formatting; only emits
    // output when the CK_LOGGING environment switch is enabled.
    const char* curFunc = __func__;
    auto print = [&curFunc](const char* format, ...) -> void {
        if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
        {
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
            va_list args;
            va_start(args, format);
            std::vfprintf(stdout, format, args);
            va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
            std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
        }
    };
    // WMMA path requires RDNA3 (gfx11) or RDNA4 (gfx12).
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        print("DeviceOp: Arch err\n");
        return false;
    }
    // fp8/bf8 inputs need gfx12; gfx11 lacks the required WMMA variants.
    if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
                 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
                 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
    {
        if(ck::is_gfx11_supported())
        {
            print("DeviceOp: gfx 11 does not support fp8\n");
            return false;
        }
    }
    if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
    {
        print("DeviceOp: Acc0 Type err\n");
        return false;
    }
    if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: A layout must be Row\n");
        return false;
    }
    if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
                   is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
    {
        print("DeviceOp: B1 layout must be Column or Row\n");
        return false;
    }
    if constexpr(!(is_same_v<CE1Layout, tensor_layout::gemm::RowMajor>))
    {
        print("DeviceOp: C layout must be Row\n");
        return false;
    }
    // Other padding modes have not been tested and do not get checked individually.
    if constexpr(GemmSpec != GemmSpecialization::Default &&
                 GemmSpec != GemmSpecialization::MNKOPadding)
    {
        print("Padding mode must be default or MNKO\n");
        return false;
    }
    // Per wmma dimensions not equal to 16 are very untested.
    if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16)
    {
        print("M, L, N per Wmma must be 16\n");
        return false;
    }
    // Multi-D variant: extra layout constraints plus the D0/D1 descriptors
    // participate in validity and vector-width checks.
    if constexpr(IsMultiD)
    {
        if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
        {
            print("DeviceOp: B0 layout must be Column\n");
            return false;
        }
        if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
        {
            print("DeviceOp: All D0s layout must be Row\n");
            return false;
        }
        if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
        {
            print("DeviceOp: All D1s layout must be Row\n");
            return false;
        }
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
                                        arg.b0_grid_desc,
                                        arg.d0s_grid_desc,
                                        arg.b1_grid_desc,
                                        arg.d1s_grid_desc,
                                        arg.c_e1_grid_desc_m_n,
                                        arg.block_2_etile_map))
        {
            return false;
        }
        // Check scalar per vector requirement
        const auto a_extent_lowest    = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
        const auto b0_extent_lowest   = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
        const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
        const auto b1_extent_lowest   = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
        const auto cde1_extent_lowest = arg.O;
        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            print("DeviceOp: Data Transfer Vector scalar err\n");
            return false;
        }
        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
        const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
                                          ? arg.b0_g_n_k_strides[2]
                                          : arg.b0_g_n_k_strides[1];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? arg.b1_g_o_n_strides[2]
                                          : arg.b1_g_o_n_strides[1];
        const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
        // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
        // and the lowest dimension stride is hardcoded to 1
        // NOTE(review): `||` makes this pass when ANY stride is 1; since the A,
        // B0 and E1 lowest strides are hardcoded to 1 by the Argument ctor, the
        // check is effectively always true. Confirm whether `&&` was intended.
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             e1_stride_lowest == 1))
        {
            print("DeviceOp: Data Vectorize transfer err\n");
            return false;
        }
    }
    else
    {
        // Plain variant: no D tensors; pass empty tuples to the validity check.
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc,
                                        arg.b0_grid_desc,
                                        Tuple<>{},
                                        arg.b1_grid_desc,
                                        Tuple<>{},
                                        arg.c_e1_grid_desc_m_n,
                                        arg.block_2_etile_map))
        {
            return false;
        }
        // Check scalar per vector requirement
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
        const auto c_extent_lowest  = arg.O;
        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            print("DeviceOp: Data Transfer Vector scalar err\n");
            return false;
        }
        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
        const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2
                                          ? arg.b0_g_n_k_strides[2]
                                          : arg.b0_g_n_k_strides[1];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? arg.b1_g_o_n_strides[2]
                                          : arg.b1_g_o_n_strides[1];
        const auto c_stride_lowest = arg.e1_g_m_o_strides[2];
        // NOTE(review): same `||` vs `&&` question as in the multi-D branch above.
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            print("DeviceOp: Data Vectorize transfer err\n");
            return false;
        }
    }
    // K must split evenly into AK1/BK1 chunks unless padding is requested.
    if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
    {
        return false;
    }
    return true;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -3,91 +3,20 @@
#pragma once
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief Batched multi-D gemm+gemm WMMA kernel entry point.
///
/// Grid layout: grid_size = batch_count * tiles; each block first derives its
/// batch index g_idx from the flat block id, offsets every tensor pointer by
/// that batch's base offset, then delegates the whole tile computation to
/// GridwiseOp::Run. Shared memory requirement is
/// GridwiseOp::GetSharedMemoryNumberOfByte() (static allocation below).
/// Device code is only emitted for gfx11/gfx12 targets.
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
    // readfirstlane keeps the batch index and offsets in scalar registers
    // (uniform across the wavefront).
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
    const long_index_t a_batch_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b0_batch_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
    const long_index_t b1_batch_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t e1_batch_offset =
        __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx)));
    // Per-batch views of the D0/D1 auxiliary tensor pointer tuples.
    auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer();
    auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer();
    static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) {
        const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
        p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset;
    });
    static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) {
        const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
        p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset;
    });
    GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
        arg.p_a_grid + a_batch_offset,
        arg.p_b0_grid + b0_batch_offset,
        p_d0s_grid,
        arg.p_b1_grid + b1_batch_offset,
        p_d1s_grid,
        arg.p_e1_grid + e1_batch_offset,
        p_shared,
        arg.a_grid_desc,
        arg.b0_grid_desc,
        arg.d0s_grid_desc,
        arg.b1_grid_desc,
        arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
        arg.e1_grid_desc_mblock_mperblock_nblock_nperblock,
        arg.a_element_op,
        arg.b0_element_op,
        arg.acc_element_op,
        arg.b1_element_op,
        arg.cde1_element_op,
        arg.block_2_etile_map);
#else
    ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
}
// Computes:
// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...)
// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...)
@@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
static constexpr index_t NumD0Tensor = D0sDataType::Size();
static constexpr index_t NumD1Tensor = D1sDataType::Size();
static constexpr auto I0 = Number<0>{};
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec
__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}
__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}
__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}
__host__ __device__ static auto
MakeD0GridDescriptor(const std::array<index_t, 3>& d0_g_m_n_lengths_vec,
const std::array<index_t, 3>& d0_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec);
}
__host__ __device__ static auto MakeD0sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d0_g_m_n_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]);
},
Number<NumD0Tensor>{});
}
__host__ __device__ static auto MakeD1sGridDescriptor(
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_lengths_vec,
const std::array<std::array<index_t, 3>, NumD0Tensor>& d1_g_m_o_strides_vec)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]);
},
Number<NumD1Tensor>{});
}
__host__ __device__ static auto
MakeE1GridDescriptor(const std::array<index_t, 3>& e1_g_m_n_lengths_vec,
const std::array<index_t, 3>& e1_g_m_n_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec);
}
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using D0sGridDesc = remove_cvref_t<decltype(MakeD0sGridDescriptor({}, {}))>;
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using D1sGridDesc = remove_cvref_t<decltype(MakeD1sGridDescriptor({}, {}))>;
using E1GridDesc = decltype(MakeE1GridDescriptor({}, {}));
/// @brief Translates a batch index into per-tensor element offsets using the
///        batch strides captured at construction. Offsets are widened to
///        long_index_t so batch * stride products cannot overflow index_t.
struct ComputeBasePtrOfStridedBatch
{
    ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
                                 index_t BatchStrideB0,
                                 std::array<index_t, NumD0Tensor> BatchStrideD0s,
                                 index_t BatchStrideB1,
                                 std::array<index_t, NumD1Tensor> BatchStrideD1s,
                                 index_t BatchStrideE1)
        : BatchStrideA0_(BatchStrideA0),
          BatchStrideB0_(BatchStrideB0),
          BatchStrideD0s_(BatchStrideD0s),
          BatchStrideB1_(BatchStrideB1),
          BatchStrideD1s_(BatchStrideD1s),
          BatchStrideE1_(BatchStrideE1)
    {
    }
    __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideA0_);
    }
    __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB0_);
    }
    // fix: parameter renamed d1_idx -> d0_idx (it selects a D0 batch stride)
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
                                                            Number<I> d0_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d0_idx]);
    }
    __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideB1_);
    }
    __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideE1_);
    }
    // fix: explicit long_index_t return type (was auto) for consistency with
    // every other getter; the deduced type was already long_index_t.
    template <index_t I>
    __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx,
                                                            Number<I> d1_idx) const
    {
        return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
    }
    private:
    index_t BatchStrideA0_;
    index_t BatchStrideB0_;
    std::array<index_t, NumD0Tensor> BatchStrideD0s_;
    index_t BatchStrideB1_;
    std::array<index_t, NumD1Tensor> BatchStrideD1s_;
    index_t BatchStrideE1_;
};
using DeviceGemmGemmCommonBase =
DeviceGemmGemm_Wmma_CShuffleV3_Common<DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
@@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CDE1ElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
D0sGridDesc,
B1GridDesc,
D1sGridDesc,
E1GridDesc,
typename DeviceGemmGemmCommonBase::AGridDesc,
typename DeviceGemmGemmCommonBase::B0GridDesc,
typename DeviceGemmGemmCommonBase::D0sGridDesc,
typename DeviceGemmGemmCommonBase::B1GridDesc,
typename DeviceGemmGemmCommonBase::D1sGridDesc,
typename DeviceGemmGemmCommonBase::E1GridDesc,
// Tiling Family
MPerBlock,
LPerBlock,
@@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;
struct RawArg : public BaseArgument
using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg<
DeviceOp,
GemmSpec,
ALayout,
B0layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
BlockSize,
MPerBlock,
LPerBlock,
KPerBlock,
NPerBlock,
ADataType,
B0DataType,
B1DataType,
AccDataType,
E1DataType,
D0sDataType,
D1sDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CDE1ElementwiseOperation,
AK1,
BK1,
L1,
MPerWmma,
LPerWmma,
BlkGemmPipelineVer,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
CDE0BlockTransferSrcScalarPerVector,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>; // IsMultiD
// Invoker
using Invoker = typename DeviceGemmGemmCommon::Invoker;
// Argument
using Argument = typename DeviceGemmGemmCommon::Argument;
static bool IsSupportedArgument(const Argument& arg)
{
using arr3 = std::array<ck::index_t, 3>;
RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
std::array<const void*, NumD0Tensor> p_d0s_grid_,
const B1DataType* p_b1_grid_,
std::array<const void*, NumD1Tensor> p_d1s_grid_,
E1DataType* p_e1_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
std::array<index_t, NumD0Tensor> StrideD0s,
index_t StrideB1,
std::array<index_t, NumD1Tensor> StrideD1s,
index_t StrideE1,
index_t BatchStrideA,
index_t BatchStrideB0,
std::array<index_t, NumD0Tensor> BatchStrideD0s,
index_t BatchStrideB1,
std::array<index_t, NumD1Tensor> BatchStrideD1s,
index_t BatchStrideE1,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CDE1ElementwiseOperation cde1_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_d0s_grid{},
p_b1_grid{p_b1_grid_},
p_d1s_grid{},
p_e1_grid{p_e1_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
cde1_element_op{cde1_element_op_},
compute_base_ptr_of_batch{BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1}
{
a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
e1_g_m_o_lengths = arr3{batch_count, M, O};
e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O]
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides);
e1_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e1_grid_desc_m_n);
block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1);
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
// D0s layout [batch_count, M, N]
d0s_g_m_n_lengths[i] = arr3{batch_count, M, N};
d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1};
// D0 pointer
p_d0s_grid(i) = static_cast<const D0DataType*>(p_d0s_grid_[i]);
// D0 desc
d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]);
});
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
// D1s layout [batch_count, M, O]
d1s_g_m_o_lengths[i] = arr3{batch_count, M, O};
d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1};
// D1 pointer
p_d1s_grid(i) = static_cast<const D1DataType*>(p_d1s_grid_[i]);
// D1 desc
d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]);
});
d1s_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
typename GridwiseOp::D0sGridPointer p_d0s_grid;
const B1DataType* p_b1_grid;
typename GridwiseOp::D1sGridPointer p_d1s_grid;
E1DataType* p_e1_grid;
// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;
arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
std::array<arr3, NumD0Tensor> d0s_g_m_n_lengths;
std::array<arr3, NumD0Tensor> d0s_g_m_n_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
std::array<arr3, NumD1Tensor> d1s_g_m_o_lengths;
std::array<arr3, NumD1Tensor> d1s_g_m_o_strides;
arr3 e1_g_m_o_lengths;
arr3 e1_g_m_o_strides;
AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CDE1ElementwiseOperation cde1_element_op;
// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
D0sGridDesc d0s_grid_desc;
B1GridDesc b1_grid_desc;
D1sGridDesc d1s_grid_desc;
typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
d1s_grid_desc_mblock_mperblock_nblock_nperblock;
E1GridDesc e1_grid_desc_m_n;
typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e1_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
// Check whether every layout in the DsLayout tuple matches RefLayout.
//
// Template parameters:
//   RefLayout  - the required layout (e.g. tensor_layout::gemm::RowMajor)
//   DsLayout   - tuple of layouts, one per D tensor
//   NumDTensor - number of entries in DsLayout to inspect
//
// Returns true iff all NumDTensor layouts equal RefLayout. Evaluated at
// compile time; used below to force all D0s/D1s tensors to be row-major.
//
// NOTE(review): a stray unreachable line
// "return DeviceGemmGemmCommon::IsSupportedArgument(arg);" — a merge/diff
// artifact referencing names not in scope here — was removed from the end
// of this function.
template <typename RefLayout, typename DsLayout, const index_t NumDTensor>
static constexpr bool CheckDLayout()
{
    bool valid = true;
    // Iterate over the DsLayout tuple at compile time.
    static_for<0, NumDTensor, 1>{}([&](auto i) {
        using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
        // Stay true only while every DLayout equals RefLayout.
        valid = valid && is_same_v<RefLayout, DLayout>;
    });
    return valid;
}
// Static argument check for this device op. Verifies, in order:
//   1. target architecture (gfx11/gfx12) and fp8 availability,
//   2. compile-time data-type / layout / padding / per-wmma constraints,
//   3. runtime grid-descriptor validity via GridwiseOp::CheckValidity,
//   4. vector-transfer divisibility and lowest-stride requirements.
// Returns false (optionally logging the reason when CK_LOGGING is enabled)
// on the first failed check.
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf() style formatting. Output is only
// emitted when the CK_LOGGING environment toggle is enabled.
const char* curFunc = __func__;
auto print = [&curFunc](const char* format, ...) -> void {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
// The format string is a runtime parameter by design; silence clang's
// -Wformat-nonliteral for this intentional usage.
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wformat-nonliteral"
#endif
va_list args;
va_start(args, format);
std::vfprintf(stdout, format, args);
va_end(args);
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
}
};
// WMMA instructions require RDNA3 (gfx11) or RDNA4 (gfx12).
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
print("DeviceOp: Arch err\n");
return false;
}
// fp8/bf8 inputs are only available on gfx12.
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
{
if(ck::is_gfx11_supported())
{
print("DeviceOp: gfx 11 does not support fp8\n");
return false;
}
}
// Gemm0 accumulator must be float or int32.
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
print("DeviceOp: Acc0 Type err\n");
return false;
}
// Layout restrictions: A row-major, B0 column-major, all D0s/D1s and E1
// row-major; B1 may be either.
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: A layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B0 layout must be Column\n");
return false;
}
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>()))
{
print("DeviceOp: All D0s layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
{
print("DeviceOp: B1 layout must be Column or Row\n");
return false;
}
if constexpr(!(CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>()))
{
print("DeviceOp: All D1s layout must be Row\n");
return false;
}
if constexpr(!(is_same_v<E1Layout, tensor_layout::gemm::RowMajor>))
{
print("DeviceOp: C layout must be Row\n");
return false;
}
// Other padding modes have not been tested and do not get checked individually.
if constexpr(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MNKOPadding)
{
print("Padding mode must be default or MNKO\n");
return false;
}
// Per wmma dimensions not equal to 16 are very untested.
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
{
print("M, L, N per Wmma must be 16\n");
return false;
}
// Runtime check of the tensor descriptors against the gridwise op's
// tiling/blocking constraints.
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
arg.b0_grid_desc,
arg.d0s_grid_desc,
arg.b1_grid_desc,
arg.d1s_grid_desc,
arg.e1_grid_desc_m_n,
arg.block_2_etile_map))
{
return false;
}
// Check scalar per vector requirement: the extent along each tensor's
// vectorized dimension must be divisible by its ScalarPerVector.
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
const auto cde1_extent_lowest = arg.O;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
print("DeviceOp: Data Transfer Vector scalar err\n");
return false;
}
// Check vector load/store requirement: the vectorized dimension must be
// contiguous in memory (stride 1).
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
const auto e1_stride_lowest = arg.e1_g_m_o_strides[2];
// NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major
// and the lowest dimension stride is hardcoded to 1
// NOTE(review): this condition only rejects the argument when NONE of the
// lowest strides is 1. If the intent is that EVERY vectorized tensor must
// be contiguous (as the scalar-per-vector check above suggests), the
// operator should be && rather than || — verify before changing.
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
e1_stride_lowest == 1))
{
print("DeviceOp: Data Vectorize transfer err\n");
return false;
}
// Without MNKO padding, K must be divisible by both K1 unroll factors.
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
{
return false;
}
return true;
}
// Polymorphic entry point: downcast the type-erased argument and forward to
// the static, strongly-typed check above.
//
// NOTE(review): an unreachable duplicate
// "return DeviceGemmGemmCommon::IsSupportedArgument(...)" statement (a
// merge/diff artifact; the common-struct alias it references is not declared
// in this file's visible code) was removed.
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
    return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
}
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
const index_t grid_size = arg.batch_count * M0 * N0;
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
constexpr TailNumber tn = tail_number;
const auto kernel =
kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3<DeviceOp,
GridwiseOp,
has_loop,
tn>;
return launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
};
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
return 0.0f;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
{
return launch_kernel(std::integral_constant<bool, true>{},
std::integral_constant<TailNumber, TailNumber::Full>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Even>{});
}
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
{
return launch_kernel(std::integral_constant<bool, false>{},
std::integral_constant<TailNumber, TailNumber::Odd>{});
}
else
{
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
return 0.0f;
}
}
else
{
printf("Invalid pipeline version!\n");
return 0.0f;
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// NOTE(review): this function is a corrupted merge/diff rendering and will
// not compile as-is: the parameter list is interrupted by an embedded
// "@@ ... @@" hunk marker (the middle parameters were elided), and the body
// contains BOTH the pre-refactor (RawArg{...}) and post-refactor
// (Argument{...}) return statements back to back. Exactly one return must
// remain; resolve against the common header
// (device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp) before building.
static auto MakeArgument(const ADataType* p_a0,
const B0DataType* p_b0,
std::array<const void*, NumD0Tensor> p_d0s,
@@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op)
{
// NOTE(review): pre-refactor return (uses this file's RawArg type).
return RawArg{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
// NOTE(review): unreachable post-refactor duplicate (uses the common
// Argument type); one of these two returns must be deleted.
return Argument{p_a0, p_b0,
p_d0s, p_b1,
p_d1s, p_e1,
MRaw, NRaw,
KRaw, Gemm1NRaw,
Batch, StrideA0,
StrideB0, StrideD0s,
StrideB1, StrideD1s,
StrideE1, BatchStrideA0,
BatchStrideB0, BatchStrideD0s,
BatchStrideB1, BatchStrideD1s,
BatchStrideE1, a0_element_op,
b0_element_op, cde0_element_op,
b1_element_op, cde1_element_op};
}
// polymorphic
// NOTE(review): corrupted merge/diff rendering — the head of this override's
// signature was elided by the embedded "@@ ... @@" hunk marker below, and the
// body contains BOTH the pre-refactor (make_unique<RawArg>) and post-refactor
// (make_unique<Argument>) return statements. Exactly one return must remain;
// restore the full signature and resolve against the common header before
// building.
@@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation c_element_op) override
{
// NOTE(review): pre-refactor return (uses this file's RawArg type).
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
// NOTE(review): unreachable post-refactor duplicate (uses the common
// Argument type); one of these two returns must be deleted.
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
p_d0s,
static_cast<const B1DataType*>(p_b1),
p_d1s,
static_cast<E1DataType*>(p_c),
M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideD0s,
StrideB1,
StrideD1s,
StrideE1,
BatchStrideA,
BatchStrideB0,
BatchStrideD0s,
BatchStrideB1,
BatchStrideD1s,
BatchStrideE1,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
static auto MakeInvoker() { return Invoker{}; }