Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-02)
Redesign the DPP8 GEMM kernel to use warp-wise component (#863)
* Redesign the DPP8 GEMM kernel to use warp-wise component
* Review: Improve error messages
* Review: Remove unnecessary empty lines
* Review: Fix M, N per thread names
* Review: Rename mfma_input_type to dpp_input_type
* Review: Fix tensor adaptor; remove unnecessary element
* Review: Remove calls to dpp_gemm's MakeCDescriptor
* Review: Add blockwise doc, change function names to include dimension names
* Review: Remove duplicated code; Move Block2CtileMap alias to the top of the file
* Review: Add __restrict__ keywords
* Review: Use MatrixPadder for padding A, B, C matrices
* Review: Remove hardcoded datatypes
* Review: Change names from FloatX to XDataType
* Review: Introduce AK0 and BK0 instead of a single K0
* Review: Remove construction of dpp_datatypes object
* Review: Rename DppInstrRunner to DppLanegroupGemm
Committed by GitHub
Parent: 3786bfe1cc
Commit: 37a8c1f756
@@ -1,370 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"

namespace ck {

/**
 * DPP8 version of the blockwise GEMM algorithm. It uses the DPP8 instruction modifier to limit
 * the amount of data loaded from LDS to registers.
 *
 * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits matrix C
 * between them in such a way that threads from the same group need the same chunk of either
 * matrix A or matrix B. Without DPP8, each thread would have to load the whole chunk from LDS
 * into its own register space. With DPP8 modifiers, each thread loads only
 * `1 / dpp8::lane_group_size` of the chunk and then shares that data with the other threads
 * from the same lane group.
 *
 * Assumptions coming from the usage of DPP8:
 * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
 *    `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size`:
 *    - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
 *      matrix A or B;
 *    - based on these values we determine which matrix to share.
 * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
 *    `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B):
 *    - the data to split must be divisible by the number of threads in the group.
 *
 * General algorithm:
 *   C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
 * A and B are visible to the whole block; C is distributed among the threads.
 * Assume:
 * 1. A:
 *    1. ABlockDesc_BK0_BM_BK1 is known at compile-time
 *    2. ABlockBuffer is DynamicBuffer
 * 2. B:
 *    1. BBlockDesc_BK0_BN_BK1 is known at compile-time
 *    2. BBlockBuffer is DynamicBuffer
 * 3. C:
 *    1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
 *    2. CThreadBuffer is StaticBuffer
 * 4. BM10BN10ThreadClusterBM10Xs::Size() == BM10BN10ThreadClusterBN10Xs::Size() == 2
 */
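// Worked example (hypothetical parameter values, for illustration only): with
// dpp8::lane_group_size == 8 and BM10BN10ThreadClusterBM10Xs[1] == 8, matrix B is the shared
// one (assumption 1). Choosing BN1PerThreadBN11 == 16 then satisfies assumption 2, so each
// thread loads only 16 / 8 == 2 of the 16 B elements from LDS and obtains the remaining 14
// from the other threads of its lane group via DPP8 cross-lane reads.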
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_BK0_BM_BK1,
          typename BBlockDesc_BK0_BN_BK1,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
                                                //          BM10BN10ThreadClusterBM101, ...>
          typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
                                                //          BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
{
    using AIndex = MultiIndex<4>;
    using BIndex = MultiIndex<4>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
    static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
    static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
    static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);

    static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
    static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];

    static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
    static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];

    static constexpr index_t BM11 = BM1PerThreadBM11;
    static constexpr index_t BN11 = BN1PerThreadBN11;

    static constexpr index_t BM1 = BM100 * BM101 * BM11;
    static constexpr index_t BN1 = BN100 * BN101 * BN11;

    static constexpr index_t BM0 = BM / BM1;
    static constexpr index_t BN0 = BN / BN1;

    // We assume that either `BM101` or `BN101` equals `dpp8::lane_group_size`. This makes all
    // threads in a lane group need the same chunk of matrix B or A, so that chunk can be
    // shared using DPP.
    static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size);
    static constexpr bool ShareB = (BM101 == dpp8::lane_group_size);
    static constexpr bool ShareA = !ShareB;

    // If DPP shares A (or B, respectively), a lane group gets `BM1PerThreadBM11`
    // (`BN1PerThreadBN11`) elements, which we split between the threads of the lane group so
    // that each thread loads less data from LDS.
    static constexpr index_t BM1PerThread =
        ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11;
    static constexpr index_t BN1PerThread =
        ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11;
    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
    {
        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
            a_block_desc_bk0_bm_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_block_bk0_bm0_bm1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
    {
        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
            b_block_desc_bk0_bn_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_block_desc_bk0_bn0_bn1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM, BN]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM0, BM1, BN0, BN1]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<BM0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_pass_through_transform(Number<BN0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
    }

    __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
    {
        return Sequence<BM0, BM11, BN0, BN11>{};
    }

    static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
        MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});

    static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
        MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});

public:
    __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0()
        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
              get_thread_local_1d_id())},
          a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()},
          b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()}
    {
        static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                          BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BM % BM1 == 0 && BN % BN1 == 0,
                      "wrong! BM and BN must be divisible by BM1 and BN1");

        static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
                          BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
                          BM10BN10ThreadClusterBN10Xs::Size() == 2,
                      "wrong! thread cluster Sequences must have exactly two elements");
    }

    __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
    {
        // lower: [BM0, BM1, BN0, BN1]
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        constexpr auto adaptor0 =
            MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();

        // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // upper: [Tid, BM0, BM11, BN0, BN11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
                       make_pass_through_transform(BM0),
                       make_pass_through_transform(BM11),
                       make_pass_through_transform(BN0),
                       make_pass_through_transform(BN11)),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }
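    // The chained adaptor above decomposes the flat thread id into its
    // (BM100, BN100, BM101, BN101) cluster coordinates and maps them to the
    // (BM0, BM1, BN0, BN1) origin of this thread's C sub-tile.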
    __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()
    {
        const auto offsetBM0 = c_thread_origin_data_idx_[I0];
        // If sharing matrix A, we need a separate BM1 offset for each thread in the lane group.
        const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] +
                                            dpp8::get_thread_idx_in_lane_group() * BM1PerThread
                                      : c_thread_origin_data_idx_[I1];
        return make_tuple(0, offsetBM0, offsetBM1, 0);
    }

    __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()
    {
        const auto offsetBN0 = c_thread_origin_data_idx_[I2];
        // If sharing matrix B, we need a separate BN1 offset for each thread in the lane group.
        const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] +
                                            dpp8::get_thread_idx_in_lane_group() * BN1PerThread
                                      : c_thread_origin_data_idx_[I3];
        return make_tuple(0, offsetBN0, offsetBN1, 0);
    }

    template <typename CThreadDesc_BM0_BM11_BN0_BN11,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
            b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

        constexpr auto threadwise_contraction =
            ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                FloatA,
                FloatB,
                FloatC,
                decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
                decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
                CThreadDesc_BM0_BM11_BN0_BN11,
                Sequence<BK0PerThread, BK1>,
                Sequence<1, BM1PerThreadBM11>,
                Sequence<1, BN1PerThreadBN11>,
                ShareA>{};
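        // K-loop structure below: the first K-step (bk0 == 0) is peeled off, then the
        // remaining K-steps are processed in increments of BK0PerThread.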
        static_for<0, BN0, 1>{}([&](auto bn0) {
            static_for<0, BM0, 1>{}([&](auto bm0) {
                a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                                   make_tuple(I0, bm0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_bk0_bm0_bm1_bk1_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);

                b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                                   make_tuple(I0, bn0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_bk0_bn0_bn1_bk1_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);

                threadwise_contraction.Run(a_thread_buf,
                                           make_tuple(I0, I0, I0, I0),
                                           b_thread_buf,
                                           make_tuple(I0, I0, I0, I0),
                                           c_thread_buf,
                                           make_tuple(bm0, I0, bn0, I0));

                static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
                    a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                                       make_tuple(bk0, bm0, I0, I0),
                                       a_block_buf,
                                       a_thread_desc_bk0_bm0_bm1_bk1_,
                                       make_tuple(I0, I0, I0, I0),
                                       a_thread_buf);

                    b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                                       make_tuple(bk0, bn0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_bk0_bn0_bn1_bk1_,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf);

                    threadwise_contraction.Run(a_thread_buf,
                                               make_tuple(I0, I0, I0, I0),
                                               b_thread_buf,
                                               make_tuple(I0, I0, I0, I0),
                                               c_thread_buf,
                                               make_tuple(bm0, I0, bn0, I0));
                });
            });
        });
    }

private:
    // A[BK0, BM0, BM1, BK1]
    static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThread>{}, Number<BK1>{}));

    // B[BK0, BN0, BN1, BK1]
    static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThread>{}, Number<BK1>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatA,
        FloatA,
        decltype(a_block_desc_bk0_bm0_bm1_bk1_),
        decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
        Sequence<BK0PerThread, 1, BM1PerThread, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                         // DimAccessOrder
        Sequence<1, 1, BM1PerThread, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                        // SrcVectorTensorContiguousDimOrder

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatB,
        FloatB,
        decltype(b_block_desc_bk0_bn0_bn1_bk1_),
        decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
        Sequence<BK0PerThread, 1, BN1PerThread, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                         // DimAccessOrder
        Sequence<1, 1, BN1PerThread, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                        // SrcVectorTensorContiguousDimOrder

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp (new file, 348 lines)
@@ -0,0 +1,348 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp"

namespace ck {

/**
 * Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded by
 * each thread, sharing the data between the threads of a lane group.
 *
 * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
 * `MRepeat` iterations along the `M` dimension and `NRepeat` along `N`.
 * In total, the algorithm runs using
 * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
 */
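// Worked example (hypothetical tile sizes, for illustration only): with MPerBlock == 128,
// MPerDpp == 32, MRepeat == 2 and NPerBlock == 64, NPerDpp == 16, NRepeat == 2, the formula
// above gives 128 / (2 * 32) * 64 / (2 * 16) == 2 * 2 == 4 waves per block.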
template <index_t BlockSize,
          typename ABDataType,
          typename AccDataType,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerDpp,
          index_t NPerDpp,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};

    static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              AccDataType,
                              MRepeat * NRepeat,
                              dpp_gemm.GetRegSizePerDpp(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
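    // Note: the merge transform above decomposes the flat thread id as
    // thread_id == (wave_m * NWaves + wave_n) * WaveSize + lane_id,
    // so GetWaveIdx() returns (wave_m, wave_n, lane_id).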
    __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_m     = wave_idx[I0];
        const auto dpp_a_idx    = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
        const auto dpp_a_idx_k  = dpp_a_idx[I0];
        const auto dpp_a_idx_m  = dpp_a_idx[I1];
        return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
    }

    __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_n     = wave_idx[I1];
        const auto dpp_b_idx    = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
        const auto dpp_b_idx_k  = dpp_b_idx[I0];
        const auto dpp_b_idx_n  = dpp_b_idx[I1];
        return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
    }

    template <index_t m0, index_t n0>
    __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx      = dpp_gemm.GetBeginOfThreadBlk();
        const auto blk_m_offset = blk_idx[I0];
        const auto blk_n_offset = blk_idx[I1];

        constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_m_offset))[I0];
        const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_n_offset))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
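    // The two unmerge adaptors above invert the M == MRepeat * MWaves * MPerDpp and
    // N == NRepeat * NWaves * NPerDpp splits, recovering this thread's absolute (m, n)
    // origin within the block's C tile for a given (m0, n0) repeat index.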
    __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "Wrong! Block descriptors should be known at the time of compilation.");

#if defined(__HIP_DEVICE_COMPILE__)
        // The host wave size can differ from the device one, so this assert could fail on the
        // host, but it only matters for device code.
        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
#endif

        static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
                      "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
        static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
                      "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
    {
        constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
        constexpr auto M = c_m_n_tblk_lens[I0];
        constexpr auto N = c_m_n_tblk_lens[I1];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
    {
        constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
        constexpr auto M = c_m_n_tblk_lens[I0];
        constexpr auto N = c_m_n_tblk_lens[I1];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerDpp>{},
                                                           Number<NPerDpp>{}));

        return c_block_desc_m0_n0_m1_n1_m2_n2;
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerDpp>{},
                                                           Number<NPerDpp>{}));
        return c_block_desc_g_m0_n0_m1_n1_m2_n2;
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return c_grid_desc_m0_n0_m1_n1_m2_n2;
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerDpp>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerDpp>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);

                static_for<0, KPerThread, KPack>{}([&](auto k) {
                    vector_type<ABDataType, KPack> a_thread_vec;
                    vector_type<ABDataType, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                        b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                    });

                    using dpp_input_type =
                        typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
                                          b_thread_vec.template AsType<dpp_input_type>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
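    // Note on the loop nest above: the A chunk is read once per MRepeat iteration, the B chunk
    // once per (MRepeat, NRepeat) pair, and the innermost static_for feeds KPack-wide vectors
    // to DppGemm, which performs the DPP-based cross-lane data exchange internally.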
protected:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // C[M, N, NumRegDpp]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
                                                         ABDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
                                                         ABDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
};

} // namespace ck
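For reference, a minimal sketch of how the new blockwise component could be instantiated. All tile sizes and the descriptor aliases ADesc and BDesc below are hypothetical placeholders, not values taken from this commit:

    // Hypothetical instantiation sketch; ADesc/BDesc stand for compile-time block
    // descriptors of shape [K0, M, K1] and [K0, N, K1] built elsewhere in the kernel.
    using BlockwiseGemm =
        ck::BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<256,        // BlockSize
                                                          ck::half_t, // ABDataType
                                                          float,      // AccDataType
                                                          ADesc,      // AK0MK1BlockDesc
                                                          BDesc,      // BK0NK1BlockDesc
                                                          32,         // MPerDpp
                                                          8,          // NPerDpp
                                                          2,          // MRepeat
                                                          2,          // NRepeat
                                                          2>;         // KPack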
@@ -4,27 +4,13 @@

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

enum struct LoopScheduler
{
    Default,
    Interwave,
};

constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
    return LoopScheduler::Interwave;
#else
    return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
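// Usage sketch (illustrative): the scheduler is typically fixed once at compile time, e.g.
//   static constexpr LoopScheduler LoopSched = make_default_loop_scheduler();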
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)