mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Multiple-ABD GEMM example (#2788)
* Multi ABD - initial commit * Clang-format fix * block gemm, unify the name of CDataType * Apply changes to mem-pipeline * Rollback prefix for DType and Layout * Gemm Kernel Basic, rename * WMMA config * Grouped GEMM * Clang-format * Dropout, name * Review v2 * Move element_wise fn to unary, remove old ones fn * clang-format * Fix issue review * WP operator adjust to universal gemm * v2 prepare * Remove unused comment * Remove vectorsize * Rollback * Adjust pipeline for abd * Shuffle argument * CI-fail fix quant * Fix ag_br pipeline * Failing tests * Typo * Single argument support
This commit is contained in:
@@ -90,10 +90,10 @@ struct BatchedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
|
||||
@@ -89,7 +89,7 @@ struct GemmKernel
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -106,10 +106,10 @@ struct GemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
static constexpr index_t NumATensor = 1;
|
||||
static constexpr index_t NumBTensor = 1;
|
||||
|
||||
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
193
include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp
Normal file
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The MultiABD GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating
|
||||
/// the kernel arguments object. It contains all the information required to build the proper
|
||||
/// kernel arguments and launch the kernel on the GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M, N, K sizes and respective strides.
|
||||
/// NumATensor describes the number of A tensors. The minimum number of tensors is 1 (required).
|
||||
/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1 (required).
|
||||
/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0 (not
|
||||
/// required).
/// The output pointer and the output stride are each held in an anonymous union, so
/// e_ptr / c_ptr and stride_E / stride_C are two names for the same storage.
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct GemmMultiABDHostArgs
|
||||
{
|
||||
/// @brief Copy every argument into the member of the same name (without the trailing '_').
/// @note k_batch_ is the fifth constructor parameter (right after e_ptr_), although the
/// k_batch member is declared, and therefore initialized, last.
CK_TILE_HOST GemmMultiABDHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: as_ptr(as_ptr_),
|
||||
bs_ptr(bs_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_As(stride_As_),
|
||||
stride_Bs(stride_Bs_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
/// One read-only pointer per A tensor.
const std::array<const void*, NumATensor> as_ptr;
|
||||
/// One read-only pointer per B tensor.
const std::array<const void*, NumBTensor> bs_ptr;
|
||||
/// One read-only pointer per D tensor.
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
/// Writable output pointer; e_ptr and c_ptr alias the same storage.
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
/// GEMM problem sizes.
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
/// One stride per A tensor.
const std::array<index_t, NumATensor> stride_As;
|
||||
/// One stride per B tensor.
const std::array<index_t, NumBTensor> stride_Bs;
|
||||
/// One stride per D tensor.
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
/// Output stride; stride_E and stride_C alias the same storage.
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
/// Value forwarded as k_batch to the kernel arguments (GemmKernelMultiABD rejects values > 1).
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief GEMM kernel front-end for tuple-typed A, B and D operands.
///
/// Every member function below forwards to the UniversalGemmKernel instantiated with the same
/// template arguments; this struct itself only adds compile-time checks on the operand types
/// and the k_batch restriction in IsSupportedArgument().
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernelMultiABD
|
||||
{
|
||||
/// @brief Alias of the UniversalGemmKernel instantiation that all member functions delegate
|
||||
/// to (it is used as a type alias here, not as a base class).
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
/// Block size taken unchanged from UniversalGemmKernel.
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using AsLayout = remove_cvref_t<typename GemmPipeline::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename GemmPipeline::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
/// @brief AsLayout and AsDataType are expected to be tuples, not scalars.
|
||||
// NOTE(review): the diagnostic text still names ALayout/ADataType, while the checked aliases
// are AsLayout/AsDataType.
static_assert(is_detected<is_tuple, AsLayout>::value &&
|
||||
is_detected<is_tuple, AsDataType>::value,
|
||||
"ALayout and ADataType must be a tuple.");
|
||||
|
||||
/// @brief BsLayout and BsDataType are expected to be tuples, not scalars.
|
||||
// NOTE(review): the diagnostic text still names BLayout/BDataType, while the checked aliases
// are BsLayout/BsDataType.
static_assert(is_detected<is_tuple, BsLayout>::value &&
|
||||
is_detected<is_tuple, BsDataType>::value,
|
||||
"BLayout and BDataType must be a tuple.");
|
||||
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"CLayout and EDataType must be a scalar.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be non-empty tuples of equal size.
|
||||
// NOTE(review): the size() > 0 condition rejects an empty D tuple, whereas the
// GemmMultiABDHostArgs documentation states that the minimum number of D tensors is 0, and
// the diagnostic text does not mention the non-empty requirement - confirm which is intended.
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
is_detected<is_tuple, DsDataType>::value &&
|
||||
DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
|
||||
"DsLayout and DsDataType must be tuples and must have the same size.");
|
||||
|
||||
/// @brief NumATensor, NumBTensor and NumDTensor are the sizes of the respective data-type tuples.
|
||||
static constexpr index_t NumATensor = AsDataType::size();
|
||||
static constexpr index_t NumBTensor = BsDataType::size();
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
/// @brief Return the name reported by UniversalGemmKernel::GetName().
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
return UniversalGemmKernel::GetName();
|
||||
}
|
||||
|
||||
/// @brief Return UniversalGemmKernel::GridSize(M, N, KBatch) unchanged.
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::GridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
/// @brief Return UniversalGemmKernel::MaxOccupancyGridSize(s) unchanged.
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
/// @brief Return UniversalGemmKernel::BlockSize() unchanged.
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Build the UniversalGemmKernel kernel arguments from MultiABD host arguments.
/// @param hostArgs Host-side problem description; its fields are repacked, in constructor
/// order, into a UniversalGemmHostArgs with the same tensor counts.
/// @return Whatever UniversalGemmKernel::MakeKernelArgs returns for that object.
CK_TILE_HOST static constexpr auto
|
||||
MakeKernelArgs(const GemmMultiABDHostArgs<NumATensor, NumBTensor, NumDTensor>& hostArgs) ->
|
||||
typename UniversalGemmKernel::KernelArgs
|
||||
{
|
||||
/// @brief Universal GEMM requires array objects and corresponding stride information for
|
||||
/// matrices A, B, and D.
|
||||
return UniversalGemmKernel::MakeKernelArgs(
|
||||
UniversalGemmHostArgs<NumATensor, NumBTensor, NumDTensor>(hostArgs.as_ptr,
|
||||
hostArgs.bs_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.k_batch,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_As,
|
||||
hostArgs.stride_Bs,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E));
|
||||
}
|
||||
|
||||
/// @brief Check whether the given kernel arguments can be run by this kernel.
/// @return false when kargs.k_batch > 1; otherwise the result of
/// UniversalGemmKernel::IsSupportedArgument(kargs).
CK_TILE_HOST static auto
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool
|
||||
{
|
||||
// Currently MultiABD kernel doesn't support k_batch > 1
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
/// @brief Device entry point: invoke a default-constructed UniversalGemmKernel on kargs.
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
|
||||
{
|
||||
UniversalGemmKernel{}.template operator()(kargs);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -95,7 +95,7 @@ struct GemmKernelMultiD
|
||||
/// @brief Specify the layout configurations for A, B, E and D
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, E and D
|
||||
@@ -114,10 +114,10 @@ struct GemmKernelMultiD
|
||||
!is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief ELayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ELayout>::value &&
|
||||
/// @brief CLayout and EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, EDataType>::value,
|
||||
"ELayout and EDataType must be scalars.");
|
||||
"CLayout and EDataType must be scalars.");
|
||||
|
||||
/// @brief DsLayout and DsDataType are expected to be tuple, not a scalar.
|
||||
static_assert(is_detected<is_tuple, DsLayout>::value &&
|
||||
|
||||
@@ -120,10 +120,10 @@ struct GroupedGemmKernel
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
/// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
"C/CLayout and C/EDataType must be scalars.");
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
@@ -364,12 +364,8 @@ struct GroupedGemmKernel
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
|
||||
@@ -157,23 +157,23 @@ struct UniversalGemmKernel
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ADataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BDataType>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsDataType>::value;
|
||||
static constexpr bool DDataTypeIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsDataType>::value;
|
||||
static constexpr bool ALayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::ALayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple =
|
||||
is_detected<is_tuple, typename GemmPipeline::BLayout>::value;
|
||||
is_detected<is_tuple, typename GemmPipeline::BsLayout>::value;
|
||||
static constexpr bool DLayoutIsTuple =
|
||||
is_detected<is_tuple, typename EpiloguePipeline::DsLayout>::value;
|
||||
|
||||
using AsLayout = std::conditional_t<ALayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ALayout>,
|
||||
remove_cvref_t<typename GemmPipeline::AsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ALayout>>>;
|
||||
using BsLayout = std::conditional_t<BLayoutIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BLayout>,
|
||||
remove_cvref_t<typename GemmPipeline::BsLayout>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BLayout>>>;
|
||||
|
||||
using DsLayout = std::conditional_t<DLayoutIsTuple,
|
||||
@@ -181,11 +181,11 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsLayout>>>;
|
||||
|
||||
using AsDataType = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::ADataType>,
|
||||
remove_cvref_t<typename GemmPipeline::AsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::ADataType>>>;
|
||||
|
||||
using BsDataType = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<typename GemmPipeline::BDataType>,
|
||||
remove_cvref_t<typename GemmPipeline::BsDataType>,
|
||||
remove_cvref_t<tuple<typename GemmPipeline::BDataType>>>;
|
||||
|
||||
using DsDataType =
|
||||
@@ -193,9 +193,12 @@ struct UniversalGemmKernel
|
||||
remove_cvref_t<typename EpiloguePipeline::DsDataType>,
|
||||
remove_cvref_t<tuple<typename EpiloguePipeline::DsDataType>>>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available
|
||||
@@ -483,7 +486,7 @@ struct UniversalGemmKernel
|
||||
bool DTesnorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, ELayout> == false)
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
@@ -529,7 +532,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
@@ -724,7 +727,7 @@ struct UniversalGemmKernel
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
@@ -818,7 +821,7 @@ struct UniversalGemmKernel
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
@@ -975,8 +978,8 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
@@ -1031,8 +1034,13 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window,
|
||||
AElementWise{},
|
||||
bs_block_window,
|
||||
BElementWise{},
|
||||
num_loop,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
Reference in New Issue
Block a user