mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Introduces a new GEMM API that splits the existing basic GEMM class into multiple specialized classes. (#2520)
* Init commit new API * apply clang-format * PreShuffle preapring * Apply Preshuffle condition to universal_gemm * Fix: convert size_t to index_t * Review changes * Mode 100755 -> 100644 --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -9,35 +9,41 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
|
||||
/// @brief The Batched GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating kernel
|
||||
/// arguments object. It contain all necessary information required to build proper kernel
|
||||
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
|
||||
/// stating all required information like M,N,K sizes and respective strides.
|
||||
struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST BatchedGemmHostArgs() = default;
|
||||
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ck_tile::index_t batch_stride_A_,
|
||||
ck_tile::index_t batch_stride_B_,
|
||||
ck_tile::index_t batch_stride_C_,
|
||||
ck_tile::index_t batch_count_)
|
||||
: GemmHostArgs(a_ptr_,
|
||||
b_ptr_,
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
{},
|
||||
stride_C_),
|
||||
CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t stride_C_,
|
||||
ck_tile::index_t batch_stride_A_,
|
||||
ck_tile::index_t batch_stride_B_,
|
||||
ck_tile::index_t batch_stride_C_,
|
||||
ck_tile::index_t batch_count_)
|
||||
: UniversalGemmHostArgs<>({a_ptr_},
|
||||
{b_ptr_},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{/*stride_Ds_*/},
|
||||
stride_C_),
|
||||
batch_stride_A(batch_stride_A_),
|
||||
batch_stride_B(batch_stride_B_),
|
||||
batch_stride_E(batch_stride_C_),
|
||||
@@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
|
||||
struct BatchedGemmKernel
|
||||
{
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
using ADataType = typename Base::ADataType;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using CDataType = typename Base::EDataType;
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using TilePartitioner = typename Base::TilePartitioner;
|
||||
using GemmPipeline = typename Base::GemmPipeline;
|
||||
using EpiloguePipeline = typename Base::EpiloguePipeline;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using CLayout = typename Base::ELayout;
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
|
||||
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
struct BatchedGemmKernelArgs : GemmKernelArgs
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
|
||||
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
index_t batch_stride_A;
|
||||
index_t batch_stride_B;
|
||||
@@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
using KernelArgs = BatchedGemmKernelArgs;
|
||||
|
||||
__host__ static constexpr auto
|
||||
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
|
||||
[[nodiscard]] CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return dim3(UniversalGemmKernel::KernelBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr BatchedGemmKernelArgs
|
||||
MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
|
||||
{
|
||||
return BatchedGemmKernelArgs{{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
{},
|
||||
return BatchedGemmKernelArgs{{hostArgs.as_ptr,
|
||||
hostArgs.bs_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
{},
|
||||
hostArgs.stride_As,
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch},
|
||||
hostArgs.batch_stride_A,
|
||||
@@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
|
||||
@@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);
|
||||
|
||||
// options
|
||||
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
|
||||
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
|
||||
splitk_batch_offset.a_k_split_offset;
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + batch_offset_A +
|
||||
splitk_batch_offset.as_k_split_offset[0];
|
||||
|
||||
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
|
||||
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
|
||||
splitk_batch_offset.b_k_split_offset;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + batch_offset_B +
|
||||
splitk_batch_offset.bs_k_split_offset[0];
|
||||
|
||||
const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
|
||||
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
|
||||
@@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
UniversalGemmKernel::RunGemm(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
185
include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp
Normal file
185
include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp
Normal file
@@ -0,0 +1,185 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The MultiD GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating kernel
|
||||
/// arguments object. It contain all necessary information required to build proper kernel
|
||||
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
|
||||
/// stating all required information like M,N,K sizes and respective strides. NumDTensor
|
||||
/// describes the number of D tensors.
|
||||
template <index_t NumDTensor = 1>
|
||||
struct GemmMultiDHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmMultiDHostArgs() = default;
|
||||
CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernelMultiD
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ALayout>::value &&
|
||||
!is_detected<is_tuple, ADataType>::value,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, BLayout>::value &&
|
||||
!is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"ELayout and EDataType must be scalars.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
is_detected<is_tuple, DsDataType>::value &&
|
||||
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
|
||||
"DsLayout and DsDataType must be tuples and must have the same size.");
|
||||
|
||||
/// @brief The sizes of NumATensor and NumBTensor have always been 1; the size of D is set by
|
||||
/// the user."
|
||||
static constexpr index_t NumATensor = 1;
|
||||
static constexpr index_t NumBTensor = 1;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
return UniversalGemmKernel::GetName();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const GemmMultiDHostArgs<NumDTensor>& hostArgs) ->
|
||||
typename UniversalGemmKernel::KernelArgs
|
||||
{
|
||||
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
||||
/// matrices A, B, and D.
|
||||
return UniversalGemmKernel::MakeKernelArgs(
|
||||
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>({hostArgs.a_ptr},
|
||||
{hostArgs.b_ptr},
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
{hostArgs.stride_A},
|
||||
{hostArgs.stride_B},
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
||||
{
|
||||
UniversalGemmKernel{}.template operator()(kargs);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -16,37 +16,116 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel
|
||||
/// arguments object. It contain all necessary information required to build proper kernel
|
||||
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
|
||||
/// stating all required information like M,N,K sizes and respective strides.
|
||||
struct GroupedGemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
GemmKernelArgs<> group_karg;
|
||||
UniversalGemmKernelArgs<> group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = delete;
|
||||
GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {}
|
||||
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
|
||||
: group_karg{karg}, block_start{0}, block_end{0}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
|
||||
struct GroupedGemmKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
//// @brief Specify the layout configurations for A, B, C/E
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, C/E
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
|
||||
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
@@ -65,8 +144,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
@@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -107,8 +185,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
@@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg = GemmKernelArgs<>{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
{},
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
{},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
auto karg =
|
||||
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
|
||||
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
|
||||
{/*ds_ptr*/},
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_a},
|
||||
{stride_b},
|
||||
{/*stride_ds*/},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
Run(kargs.group_karg, block_idx_2d, block_idx_z);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const GemmKernelArgs<>& kargs,
|
||||
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
@@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
|
||||
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
|
||||
splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
|
||||
splitk_batch_offset.bs_k_split_offset[0];
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
@@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
}
|
||||
else
|
||||
{
|
||||
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
Base::RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
|
||||
* batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
@@ -234,7 +321,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GemmKernelArgs<>& kargs,
|
||||
const UniversalGemmKernelArgs<>& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset);
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
@@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
|
||||
1169
include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
Normal file
1169
include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user