Wmma support for multiple ABD GEMM (#2803)

* multi_abd wmma support:

 - Add multiple A and B support to multiple D implementation (gridwise level)
 - Add multi_abd GEMM (device level)
 - Add instances (xdl parity)
 - Add tests (both xdl and wmma)
 - Add examples
 - Add ckProfiler support (both xdl and wmma)

* Fix bug in device print function

* Fix unused template parameter

* Fix batched gemm for multiABD gridwise implementation

* Fix gemm_universal_reduce with multiABDs gridwise implementation

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Enrico Degregori
2025-09-23 03:49:06 +02:00
committed by GitHub
parent de47ae2fdf
commit 3d29bff2f0
38 changed files with 5343 additions and 312 deletions

View File

@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -39,8 +40,8 @@ namespace ck {
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AsDataType A tensors data types.
/// @tparam BsDataType B tensors data types.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
@@ -129,8 +130,8 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -181,8 +182,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -233,8 +234,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -305,8 +306,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::CalculateMPadded;
using Base::CalculateNBlock;
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeAsGridDescriptor_AK0_M_AK1;
using Base::MakeBsGridDescriptor_BK0_N_BK1;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
@@ -320,24 +321,30 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;
using typename Base::AsGridPointer;
using typename Base::BsGridPointer;
using typename Base::DsGridPointer;
using AsDataType_ = AsDataType;
using BsDataType_ = BsDataType;
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideAs{StrideAs_},
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
@@ -355,7 +362,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
<< "SAs: {";
static_for<0, NumATensor, 1>{}([&](auto i) {
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
});
std::cout << "}, " << "SBs: {";
static_for<0, NumBTensor, 1>{}([&](auto i) {
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
});
std::cout << "}, ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
@@ -373,8 +388,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
@@ -391,15 +406,15 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
@@ -407,9 +422,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
: Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
@@ -417,9 +432,27 @@ struct GridwiseGemm_wmma_cshuffle_v3
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
});
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
@@ -434,8 +467,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
return (Problem::KBatch > 1) && (!is_reduce);
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
AsGridPointer p_as_grid;
BsGridPointer p_bs_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
@@ -452,29 +485,39 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
// Note: in xdl implementation multiple AB supports one layout
// but multiple strides, so we create an array of offsets with
// the same values.
// It should be fixed later on. Once we will have a thread transfer
// more flexible.
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = k_id * karg.KRead / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = k_id * k0_offset / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
@@ -497,8 +540,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t c_reduce_offset;
};
@@ -514,8 +557,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
@@ -524,10 +567,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -562,20 +605,20 @@ struct GridwiseGemm_wmma_cshuffle_v3
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
@@ -595,10 +638,26 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i];
});
// shift B matrices pointer for splitk
BsGridPointer p_bs_grid_splitk;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,

View File

@@ -22,8 +22,9 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename BScaleType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -76,8 +77,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -123,15 +124,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
PermuteA,
PermuteB>
{
using BScaleType = ck::half_t;
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AsDataType,
BsDataType,
AccDataType,
CShuffleDataType,
DsDataType,
@@ -202,8 +201,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::CalculateMPadded;
using Base::CalculateNBlock;
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeAsGridDescriptor_AK0_M_AK1;
using Base::MakeBsGridDescriptor_BK0_N_BK1;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
@@ -217,7 +216,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumATensor;
using Base::NumBTensor;
using Base::NumDTensor;
using typename Base::AsGridPointer;
using typename Base::BsGridPointer;
using typename Base::DsGridPointer;
struct Problem
@@ -225,8 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
@@ -234,8 +237,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideAs{StrideAs_},
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleB{StrideScaleB_},
@@ -254,7 +257,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
<< "SAs: {";
static_for<0, NumATensor, 1>{}([&](auto i) {
std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : "");
});
std::cout << "}, " << "SBs: {";
static_for<0, NumBTensor, 1>{}([&](auto i) {
std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : "");
});
std::cout << "}, ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
@@ -273,8 +284,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleB;
@@ -292,15 +303,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
__host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
std::array<const void*, NumBTensor> p_bs_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumATensor> StrideAs_,
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
@@ -310,9 +321,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
: Problem{M_,
N_,
K_,
StrideAs_,
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_b_scale_grid{p_b_scale_grid_},
@@ -321,6 +340,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
// populate pointer, desc for As
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
// A pointer
p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
});
// populate pointer, desc for Bs
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
// B pointer
p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -338,8 +373,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
return (Problem::KBatch > 1) && (!is_reduce);
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
AsGridPointer p_as_grid;
BsGridPointer p_bs_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
@@ -355,29 +390,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
// Note: in xdl implementation multiple AB supports one layout
// but multiple strides, so we create an array of offsets with
// the same values.
// It should be fixed later on. Once we will have a thread transfer
// more flexible.
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
static_for<0, NumATensor, 1>{}(
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = k_id * karg.KRead / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = k_id * k0_offset / BPackedSize;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
@@ -410,8 +455,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t c_reduce_offset;
};
@@ -423,7 +468,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK, typename BScaleType>
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
@@ -488,8 +533,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const BScaleType* p_b_scale_grid,
@@ -499,10 +544,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -542,20 +587,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
@@ -575,10 +620,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_splitk(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
splitk_batch_offset.a_k_split_offset[i];
});
// shift B matrices pointer for splitk
BsGridPointer p_bs_grid_splitk;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_splitk(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -61,8 +62,8 @@ template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
@@ -119,6 +120,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
using LDSTypeA =
typename std::conditional<(NumATensor > 1),
ComputeTypeA,
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
using LDSTypeB =
typename std::conditional<(NumBTensor > 1),
ComputeTypeB,
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
@@ -136,14 +149,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t>)
return 2;
else
return 1;
@@ -230,6 +243,31 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
static constexpr auto MakeAsGridPointer()
{
return generate_tuple(
[&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
return static_cast<const ADataType_*>(nullptr);
},
Number<NumATensor>{});
}
static constexpr auto MakeBsGridPointer()
{
return generate_tuple(
[&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
return static_cast<const BDataType_*>(nullptr);
},
Number<NumBTensor>{});
}
using AsGridPointer = decltype(MakeAsGridPointer());
using BsGridPointer = decltype(MakeBsGridPointer());
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
@@ -314,6 +352,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
__host__ __device__ static auto
MakeAsGridDescriptor_AK0_M_AK1(const index_t M,
const index_t MPad,
const index_t K,
const index_t KPad,
const std::array<index_t, NumATensor>& StrideAs,
const index_t AK0)
{
return generate_tuple(
[&](auto i) {
return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
},
Number<NumATensor>{});
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
@@ -330,7 +383,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
static_assert(!(is_same_v<remove_cvref_t<LDSTypeB>, pk_i4_t> &&
GemmSpec != GemmSpecialization::Default),
"pk_i4_t does not support padding");
@@ -424,6 +477,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
__host__ __device__ static auto
MakeBsGridDescriptor_BK0_N_BK1(const index_t K,
const index_t KPad,
const index_t N,
const index_t NPad,
const std::array<index_t, NumBTensor>& StrideBs,
const index_t BK0)
{
return generate_tuple(
[&](auto i) {
return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
},
Number<NumBTensor>{});
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
{
@@ -557,7 +625,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize;
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -604,20 +672,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto KThreadRead = 64 / MPerWmma;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128)
? 1
: ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
: ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0
? M0
: 128 / (AK1Number * MPerWmma * sizeof(ADataType)));
: 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
@@ -694,7 +762,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -738,20 +806,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto KThreadRead = 64 / NPerWmma;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(BDataType));
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128)
constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128)
? 1
: ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0
: ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0
? N0
: 128 / (BK1Number * NPerWmma * sizeof(BDataType)));
: 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
@@ -836,8 +904,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
ComputeTypeB,
AccDataType,
@@ -1120,11 +1188,24 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize +
b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize),
c_block_size * sizeof(CShuffleDataType));
}
template <index_t numElements, typename Type>
__device__ __forceinline__ static auto get_first_element_workaround(Type& array)
{
if constexpr(numElements > 1)
{
return array;
}
else
{
return array[I0];
}
}
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -1133,13 +1214,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
__device__ static void Run(AsGridPointer p_as_grid,
BsGridPointer p_bs_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
@@ -1152,10 +1233,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& num_k_block_per_scale,
BScaleStruct& b_scale_struct)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto as_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
},
Number<NumATensor>{});
const auto bs_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
},
Number<NumBTensor>{});
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1183,66 +1274,144 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// workaround because v7r2 is not as general as v4r1
auto get_a_blockwise_transfer = [&]() {
if constexpr(NumATensor > 1)
{
const auto idx_as_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<LDSTypeA>,
AGridDesc_AK0_M_K1,
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
idx_as_block_begin,
tie(a_block_desc_ak0_m_ak1),
make_tuple(make_multi_index(0, 0, 0)),
a_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
remove_cvref_t<tuple_element_t<0, AsDataType>>,
decltype(as_grid_desc_ak0_m_ak1[I0]),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
as_grid_desc_ak0_m_ak1[I0],
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_transfer();
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// workaround because v7r2 is not as general as v4r1
auto get_b_blockwise_transfer = [&]() {
if constexpr(NumBTensor > 1)
{
const auto idx_bs_block_begin = generate_tuple(
[&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
return ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<LDSTypeB>,
BGridDesc_BK0_N_K1,
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
Sequence<true>,
BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
idx_bs_block_begin,
tie(b_block_desc_bk0_n_bk1),
make_tuple(make_multi_index(0, 0, 0)),
b_element_op};
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
remove_cvref_t<tuple_element_t<0, BsDataType>>,
decltype(bs_grid_desc_bk0_n_bk1[I0]),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
bs_grid_desc_bk0_n_bk1[I0],
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto b_blockwise_copy = get_b_blockwise_transfer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1250,12 +1419,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
reinterpret_cast<LDSTypeB*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(LDSTypeA) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1267,25 +1436,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
(as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
KPerBlock);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
get_first_element_workaround<NumATensor>(as_grid_desc_ak0_m_ak1),
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
get_first_element_workaround<NumATensor>(as_grid_buf),
a_block_buf,
a_block_slice_copy_step,
get_first_element_workaround<NumBTensor>(bs_grid_desc_bk0_n_bk1),
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
get_first_element_workaround<NumBTensor>(bs_grid_buf),
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
{