Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 13:11:25 +00:00)
Redesign the DPP8 GEMM kernel to use warp-wise component (#863)
* Redesign the DPP8 GEMM kernel to use warp-wise component
* Review: Improve error messages
* Review: Remove unnecessary empty lines
* Review: Fix M, N per thread names
* Review: Rename mfma_input_type to dpp_input_type
* Review: Fix tensor adaptor; remove unnecessary element
* Review: Remove calls to dpp_gemm's MakeCDescriptor
* Review: Add blockwise doc, change function names to include dimension names
* Review: Remove duplicated code; Move Block2CtileMap alias to the top of the file
* Review: Add __restrict__ keywords
* Review: Use MatrixPadder for padding A, B, C matrices
* Review: Remove hardcoded datatypes
* Review: Change names from FloatX to XDataType
* Review: Introduce AK0 and BK0 instead of a single K0
* Review: Remove construction of dpp_datatypes object
* Review: Rename DppInstrRunner to DppLanegroupGemm
Committed by GitHub
Parent: 3786bfe1cc
Commit: 37a8c1f756
@@ -1,370 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp"

namespace ck {

/**
 * DPP8 version of the blockwise GEMM algorithm. It uses the DPP8 instruction modifier to limit
 * the data loaded from LDS to registers.
 *
 * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits matrix C
 * between them in such a way that threads from the same group need the same chunk of either
 * matrix A (or B, respectively). Without DPP8, each thread would need to load the
 * whole chunk from LDS into its own register space.
 * Using DPP8 modifiers allows each thread to load less data, exactly `1 / dpp8::lane_group_size`
 * of the chunk, and then share that data with the other threads from the same lane group.
 *
 * Assumptions coming from the usage of DPP8:
 * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or
 *    `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` -
 *    - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either
 *      matrix A or B;
 *    - based on these values we determine which matrix to share.
 * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or
 *    `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) -
 *    - we have to make sure that the data to split is divisible by the number of
 *      threads in the group.
 *
 * General algorithm:
 *   C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
 * A and B are visible to the whole block; C is distributed among the threads.
 * Assume:
 * 1. A:
 *    1. ABlockDesc_BK0_BM_BK1 is known at compile-time
 *    2. ABlockBuffer is DynamicBuffer
 * 2. B:
 *    1. BBlockDesc_BK0_BN_BK1 is known at compile-time
 *    2. BBlockBuffer is DynamicBuffer
 * 3. C:
 *    1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
 *    2. CThreadBuffer is StaticBuffer
 * 4. BM10BN10ThreadClusterBM10Xs::Size() == BM10BN10ThreadClusterBN10Xs::Size() == 2
 */
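// --- Illustrative sketch (editorial addition, not part of the original file) ---
// How DPP8 lane-group sharing shrinks the per-thread LDS load, using hypothetical
// tile sizes; `lane_group_size` mirrors `dpp8::lane_group_size`.
namespace dpp8_sharing_example {
constexpr int lane_group_size  = 8;  // threads that exchange data via DPP8
constexpr int BM1PerThreadBM11 = 16; // hypothetical per-thread BM1 chunk
// Assumption 2 above: the shared chunk must be divisible by the lane-group size.
static_assert(BM1PerThreadBM11 % lane_group_size == 0, "chunk not divisible by lane group");
// When A is shared, each thread loads only 1 / lane_group_size of the chunk from LDS
// and receives the remaining elements from its lane-group neighbours via DPP8.
constexpr int BM1PerThread = BM1PerThreadBM11 / lane_group_size; // 2 elements instead of 16
} // namespace dpp8_sharing_example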
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ABlockDesc_BK0_BM_BK1,
|
||||
typename BBlockDesc_BK0_BN_BK1,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
|
||||
// BM10BN10ThreadClusterBM101, ...>
|
||||
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
|
||||
// BM10BN10ThreadClusterBN101, ...>
|
||||
index_t AThreadCopyScalarPerVector_BM11,
|
||||
index_t BThreadCopyScalarPerVector_BN11,
|
||||
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0
|
||||
{
|
||||
using AIndex = MultiIndex<4>;
|
||||
using BIndex = MultiIndex<4>;
|
||||
using CIndex = MultiIndex<4>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
|
||||
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
|
||||
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
|
||||
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
|
||||
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
|
||||
|
||||
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
|
||||
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
|
||||
|
||||
static constexpr index_t BM11 = BM1PerThreadBM11;
|
||||
static constexpr index_t BN11 = BN1PerThreadBN11;
|
||||
|
||||
static constexpr index_t BM1 = BM100 * BM101 * BM11;
|
||||
static constexpr index_t BN1 = BN100 * BN101 * BN11;
|
||||
|
||||
static constexpr index_t BM0 = BM / BM1;
|
||||
static constexpr index_t BN0 = BN / BN1;
|
||||
|
||||
// We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. This makes all
// threads in a lane group need the same chunk of matrix B or A, so we can share it using DPP.
|
||||
static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size);
|
||||
static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false;
|
||||
static constexpr bool ShareA = !ShareB;
|
||||
|
||||
// If DPP shares A (or B, respectively), the lane group gets `BM1PerThreadBM11`
// (or `BN1PerThreadBN11`) elements, so we split them between the threads in the lane group
// so that each thread loads less data from LDS.
|
||||
static constexpr index_t BM1PerThread =
|
||||
ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11;
|
||||
static constexpr index_t BN1PerThread =
|
||||
ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11;
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
|
||||
{
|
||||
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
|
||||
a_block_desc_bk0_bm_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_block_bk0_bm0_bm1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
|
||||
{
|
||||
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
|
||||
b_block_desc_bk0_bn_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_block_desc_bk0_bn0_bn1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
|
||||
{
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM, BN]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
|
||||
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(Number<BM0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_pass_through_transform(Number<BN0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
|
||||
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
return Sequence<BM0, BM11, BN0, BN11>{};
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
|
||||
|
||||
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()},
|
||||
b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()}
|
||||
{
|
||||
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
|
||||
|
||||
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
|
||||
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
|
||||
BM10BN10ThreadClusterBN10Xs::Size() == 2,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
|
||||
{
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
constexpr auto adaptor0 =
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
|
||||
|
||||
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// upper: [Tid, BM0, BM11, BN0, BN11]
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
|
||||
make_pass_through_transform(BM0),
|
||||
make_pass_through_transform(BM11),
|
||||
make_pass_through_transform(BN0),
|
||||
make_pass_through_transform(BN11)),
|
||||
make_tuple(
|
||||
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
|
||||
|
||||
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
|
||||
}
|
||||
|
||||
__device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()
|
||||
{
|
||||
const auto offsetBM0 = c_thread_origin_data_idx_[I0];
|
||||
// If sharing matrix A, we need a separate BM1 offset for each thread in lane group.
|
||||
const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] +
|
||||
dpp8::get_thread_idx_in_lane_group() * BM1PerThread
|
||||
: c_thread_origin_data_idx_[I1];
|
||||
return make_tuple(0, offsetBM0, offsetBM1, 0);
|
||||
}
|
||||
|
||||
__device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()
|
||||
{
|
||||
const auto offsetBN0 = c_thread_origin_data_idx_[I2];
|
||||
// If sharing matrix B, we need a separate BN1 offset for each thread in lane group.
|
||||
const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] +
|
||||
dpp8::get_thread_idx_in_lane_group() * BN1PerThread
|
||||
: c_thread_origin_data_idx_[I3];
|
||||
return make_tuple(0, offsetBN0, offsetBN1, 0);
|
||||
}
|
||||
|
||||
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
|
||||
const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_contraction =
|
||||
ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
Sequence<BK0PerThread, BK1>,
|
||||
Sequence<1, BM1PerThreadBM11>,
|
||||
Sequence<1, BN1PerThreadBN11>,
|
||||
ShareA>{};
|
||||
|
||||
static_for<0, BN0, 1>{}([&](auto bn0) {
|
||||
static_for<0, BM0, 1>{}([&](auto bm0) {
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, bm0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, bn0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(bm0, I0, bn0, I0));
|
||||
|
||||
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(bk0, bm0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(bk0, bn0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(bm0, I0, bn0, I0));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
// A[BK0, BM0, BM1, BK1]
|
||||
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThread>{}, Number<BK1>{}));
|
||||
|
||||
// B[BK0, BN0, BN1, BK1]
|
||||
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThread>{}, Number<BK1>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||
FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BM1PerThread, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BN1PerThread, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
CIndex c_thread_origin_data_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp (new file, 348 lines)
@@ -0,0 +1,348 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp"

namespace ck {

/**
 * Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded by
 * each thread by sharing the data between threads in a lanegroup.
 *
 * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
 * `MRepeat` iterations along the `M` dimension and `NRepeat` along the `N` one.
 * In total, the algorithm runs using
 * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
 */
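// --- Illustrative sketch (editorial addition, not part of the original file) ---
// Wave count implied by the formula above, using hypothetical tile sizes.
namespace dpp_wave_count_example {
constexpr int MPerBlock = 128, NPerBlock = 64; // hypothetical block tile
constexpr int MPerDpp   = 16,  NPerDpp   = 16; // hypothetical per-DPP C tile
constexpr int MRepeat   = 2,   NRepeat   = 2;  // iterations per wave in M and N
// Each wave covers MRepeat * MPerDpp rows and NRepeat * NPerDpp columns of C.
constexpr int MWaves = MPerBlock / (MRepeat * MPerDpp); // 4
constexpr int NWaves = NPerBlock / (NRepeat * NPerDpp); // 2
static_assert(MWaves * NWaves == 8, "this configuration launches 8 waves per block");
} // namespace dpp_wave_count_example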
template <index_t BlockSize,
|
||||
typename ABDataType,
|
||||
typename AccDataType,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerDpp,
|
||||
index_t NPerDpp,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto dpp_gemm = DppGemm<ABDataType, MPerDpp, NPerDpp, KPack>{};
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MRepeat * NRepeat,
|
||||
dpp_gemm.GetRegSizePerDpp(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
|
||||
const auto dpp_a_idx_k = dpp_a_idx[I0];
|
||||
const auto dpp_a_idx_m = dpp_a_idx[I1];
|
||||
return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
|
||||
const auto dpp_b_idx_k = dpp_b_idx[I0];
|
||||
const auto dpp_b_idx_n = dpp_b_idx[I1];
|
||||
return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
|
||||
const auto blk_m_offset = blk_idx[I0];
|
||||
const auto blk_n_offset = blk_idx[I1];
|
||||
|
||||
constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_m_offset))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_n_offset))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
|
||||
{
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"Wrong! Block descriptors should be known at the time of compilation.");
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
// The host wave size can differ from the device one, so this assert could fail on the host,
// but it only matters for the device.
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
#endif
|
||||
|
||||
static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
|
||||
"Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
|
||||
static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
|
||||
"Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
|
||||
{
|
||||
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
|
||||
constexpr auto M = c_m_n_tblk_lens[I0];
|
||||
constexpr auto N = c_m_n_tblk_lens[I1];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
|
||||
{
|
||||
constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
|
||||
constexpr auto M = c_m_n_tblk_lens[I0];
|
||||
constexpr auto N = c_m_n_tblk_lens[I1];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerDpp>{},
|
||||
Number<NPerDpp>{}));
|
||||
|
||||
return c_block_desc_m0_n0_m1_n1_m2_n2;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1,
|
||||
Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerDpp>{},
|
||||
Number<NPerDpp>{}));
|
||||
return c_block_desc_g_m0_n0_m1_n1_m2_n2;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
|
||||
|
||||
return c_grid_desc_m0_n0_m1_n1_m2_n2;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_G_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
|
||||
{
|
||||
const auto G = c_grid_desc_g_m_n.GetLength(I0);
|
||||
const auto M = c_grid_desc_g_m_n.GetLength(I1);
|
||||
const auto N = c_grid_desc_g_m_n.GetLength(I2);
|
||||
|
||||
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_g_m_n,
|
||||
make_tuple(make_pass_through_transform(G),
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
|
||||
|
||||
return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerDpp>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerDpp>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
|
||||
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ABDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
static_for<0, KPerThread, KPack>{}([&](auto k) {
|
||||
vector_type<ABDataType, KPack> a_thread_vec;
|
||||
vector_type<ABDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
|
||||
});
|
||||
|
||||
using dpp_input_type =
|
||||
typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
dpp_gemm.template Run(a_thread_vec.template AsType<dpp_input_type>(),
|
||||
b_thread_vec.template AsType<dpp_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerThread]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
|
||||
// B[N0, N1, N2, KPerThread]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
|
||||
// C[M, N, NumRegDpp]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
|
||||
ABDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<ABDataType,
|
||||
ABDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerThread>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -4,27 +4,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct LoopScheduler
|
||||
{
|
||||
Default,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
constexpr LoopScheduler make_default_loop_scheduler()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
|
||||
return LoopScheduler::Interwave;
|
||||
#else
|
||||
return LoopScheduler::Default;
|
||||
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum struct GemmDlAlgorithm
|
||||
{
|
||||
Default, // Uses DOT vector instructions
|
||||
Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -60,7 +59,6 @@ template <
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
@@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlg>;
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
using AGridDesc_K0_M0_M1_K1 =
|
||||
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
@@ -375,8 +372,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
true,
|
||||
true,
|
||||
GemmDlAlg>;
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -402,8 +398,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
true,
|
||||
false,
|
||||
GemmDlAlg>;
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -429,8 +424,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
false,
|
||||
true,
|
||||
GemmDlAlg>;
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -456,8 +450,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DefaultBlock2CTileMap>,
|
||||
false,
|
||||
false,
|
||||
GemmDlAlg>;
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -492,16 +485,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1030")
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t K1,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs,
|
||||
typename M1N1ThreadClusterN1Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
K1,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm::Dpp8>
|
||||
|
||||
{
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmDlDpp8"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< M1PerThread << ", "
|
||||
<< N1PerThread << ", "
|
||||
<< KPerThread
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerDpp,
|
||||
ck::index_t NPerDpp,
|
||||
ck::index_t MDppPerWave,
|
||||
ck::index_t NDppPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGemmDpp : public DeviceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerDpp,
|
||||
NPerDpp,
|
||||
AK1,
|
||||
BK1,
|
||||
MDppPerWave,
|
||||
NDppPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
NumPrefetch,
|
||||
PipelineVer>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
karg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
|
||||
}
|
||||
|
||||
const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1030")
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmDpp"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerDpp << ", "
|
||||
<< NPerDpp << ", "
|
||||
<< MDppPerWave << ", "
|
||||
<< NDppPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1
|
||||
<< ">"
|
||||
<< " NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -7,11 +7,9 @@
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
|
||||
@@ -19,8 +17,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
@@ -29,8 +25,7 @@ template <typename GridwiseGemm,
|
||||
typename CGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
@@ -43,13 +38,6 @@ __global__ void
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
// DPP8 is currently only supported on gfx1030
|
||||
#if !defined(__gfx1030__)
|
||||
if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
@@ -100,8 +88,7 @@ template <index_t BlockSize,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
|
||||
__host__ __device__ static constexpr auto GetBlockwiseGemm()
|
||||
{
|
||||
if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
|
||||
{
|
||||
return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
ABlockDesc_BK0_BM_BK1,
|
||||
BBlockDesc_BK0_BN_BK1,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
ABlockDesc_BK0_BM_BK1,
|
||||
BBlockDesc_BK0_BN_BK1,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
|
||||
using CGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
|
||||
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp (new file, 701 lines)
@@ -0,0 +1,701 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
    kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
    const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
    const auto c_grid_desc_m_n = amd_wave_read_first_lane(
        GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}

template <index_t BlockSize,
          typename ABDataType,
          typename AccDataType,
          typename CDataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerDpp,
          index_t NPerDpp,
          index_t AK1Value,
          index_t BK1Value,
          index_t MDppPerWave,
          index_t NDppPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto AK1         = Number<AK1Value>{};
    static constexpr auto BK1         = Number<BK1Value>{};
    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

    static constexpr auto max_lds_align = math::lcm(AK1, BK1);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    // return block_id to C matrix tile idx (m0, n0) mapping
    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

    __host__ static auto CalculateGridSize(index_t M, index_t N)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
    }

    __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
    __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }

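    // Illustrative sketch (hypothetical helper, not used by the kernel): the padding helpers
    // above only round the raw problem sizes up to the next full tile, e.g. with a hypothetical
    // MPerBlock of 128, an M of 1000 is padded to 8 * 128 = 1024. The same rounding spelled out
    // without the math:: utility:
    __host__ static constexpr index_t CalculateMPaddedSketch(index_t M)
    {
        return (M + MPerBlock - 1) / MPerBlock * MPerBlock; // mirrors CalculateMPadded above
    }
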
    // Argument
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              AK0{CalculateAK0(K)},
              BK0{CalculateBK0(K)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t MPadded;
        index_t NPadded;
        index_t AK0;
        index_t BK0;
    };

    // Argument
    struct Argument : public Problem, public tensor_operation::device::BaseArgument
    {
        __host__ Argument(const ABDataType* p_a_grid_,
                          const ABDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const ABDataType* p_a_grid;
        const ABDataType* p_b_grid;
        CDataType* p_c_grid;
    };

    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1),
                    make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1), max_lds_align);
            }
        }();

        return a_block_desc_ak0_m_ak1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1),
                    make_tuple(Number<NPerBlock + 1>{} * BK1, BK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1), max_lds_align);
            }
        }();

        return b_block_desc_bk0_n_bk1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
    }

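    // Illustrative sketch (hypothetical helper, not used anywhere): the same quantity as above,
    // written without the LDS-padding and alignment round-ups so the raw tile footprint is
    // visible. Since AK0PerBlock * AK1 == BK0PerBlock * BK1 == KPerBlock, the A tile holds
    // KPerBlock * MPerBlock elements and the B tile KPerBlock * NPerBlock elements.
    __host__ __device__ static constexpr index_t GetLdsSizeLowerBoundSketch()
    {
        return (KPerBlock * MPerBlock + KPerBlock * NPerBlock) * sizeof(ABDataType);
    }
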
    __host__ static constexpr bool CheckValidity(const Problem& problem)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
                      "Wrong! AK1 must be known at the time of compilation.");
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
                      "Wrong! BK1 must be known at the time of compilation.");

        static_assert(
            MPerBlock % (MPerDpp * MDppPerWave) == 0,
            "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
        static_assert(
            NPerBlock % (NPerDpp * NDppPerWave) == 0,
            "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");

        static_assert(
            KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
            "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");

        static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! AK1Value must be divisible by "
                      "ABlockTransferDstScalarPerVector_K1");

        static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! BK1Value must be divisible by "
                      "BBlockTransferDstScalarPerVector_K1");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.M % MPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.N % NPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.M % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(problem.N % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.K % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if(problem.K % KPerBlock != 0)
        {
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = problem.K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        return true;
    }

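    // Illustrative usage sketch (hypothetical wrapper, not part of the original file): a
    // host-side device-op layer would typically run the check above on the user-provided sizes
    // before launching kernel_gemm_dpp, e.g. through a thin wrapper like this one.
    __host__ static bool IsSupportedArgumentSketch(const Argument& karg)
    {
        return CheckValidity(karg); // Argument derives from Problem
    }
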
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const auto num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);

        using BlockwiseGemm =
            BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize,
                                                          ABDataType,
                                                          AccDataType,
                                                          decltype(a_block_desc_ak0_m_ak1),
                                                          decltype(b_block_desc_bk0_n_bk1),
                                                          MPerDpp,
                                                          NPerDpp,
                                                          MDppPerWave,
                                                          NDppPerWave,
                                                          KPack>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
    }

    static constexpr auto matrix_padder =
        ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock};

    __device__ static auto
    MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

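    // Illustrative sketch (hypothetical helper): the unmerge transform above splits the K axis
    // so that an (ak0, ak1) coordinate pair addresses the original element k = ak0 * AK1 + ak1,
    // with ak1 running fastest; the small mirror below only spells out that index arithmetic.
    __host__ __device__ static constexpr index_t MapAK0AK1ToKSketch(index_t ak0, index_t ak1)
    {
        return ak0 * AK1Value + ak1; // for illustration only
    }
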
    __device__ static auto
    MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(BK0, BK1Value))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
    }

    __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        const auto block_2_ctile_map =
            Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
        {
            return;
        }

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<AK0PerBlock, MPerBlock, AK1>,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABDataType,
                                                ABDataType,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                ABlockTransferSrcVectorDim,
                                                2,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<BK0PerBlock, NPerBlock, BK1>,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                ABDataType,
                                                ABDataType,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        // a_mtx[AK0PerBlock, MPerBlock] is in LDS
        // b_mtx[BK0PerBlock, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);
        auto blockwise_gemm =
            BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize,
                                                          ABDataType,
                                                          AccDataType,
                                                          decltype(a_block_desc_ak0_m_ak1),
                                                          decltype(b_block_desc_bk0_n_bk1),
                                                          MPerDpp,
                                                          NPerDpp,
                                                          MDppPerWave,
                                                          NDppPerWave,
                                                          KPack>();

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
        // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                          a_block_desc_ak0_m_ak1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_bk0_n_bk1,
                                                          b_block_desc_bk0_n_bk1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          num_k_block_main_loop);

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

            constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
            constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
                                                   decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
                                                   CElementwiseOperation,
                                                   Sequence<M0, N0, I1, I1, MPerThread, NPerThread>,
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
                                                   CThreadTransferDstScalarPerVector,
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
                    c_grid_desc_m0_n0_m1_n1_m2_n2,
                    make_multi_index(m_thread_data_on_grid_idx[I0],
                                     n_thread_data_on_grid_idx[I0],
                                     m_thread_data_on_grid_idx[I1],
                                     n_thread_data_on_grid_idx[I1],
                                     m_thread_data_on_grid_idx[I2],
                                     n_thread_data_on_grid_idx[I2]),
                    c_element_op};

            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_m0_n0_m1_n1_m2_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
@@ -4,7 +4,8 @@
#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

namespace ck {

@@ -1,136 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"

namespace ck {

/**
 * Threadwise contraction using dot instructions with DPP8 modifier.
 *
 * Assumptions:
 * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
 *    are known at compile-time;
 * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
 * 3. `TM0` is equal to 1 and `TN0` is equal to 1;
 * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
 *    the size of the lane group (`dpp8::lane_group_size`).
 */
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AThreadDesc_TK0_TM0_TM1_TK1,
          typename BThreadDesc_TK0_TN0_TN1_TK1,
          typename CThreadDesc_TM0_TM1_TN0_TN1,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
          bool ShareA,
          typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                                 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                                 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t TK0 = TKLengths{}[I0];
    static constexpr index_t TK1 = TKLengths{}[I1];
    static constexpr index_t TM0 = TMLengths{}[I0];
    static constexpr index_t TM1 = TMLengths{}[I1];
    static constexpr index_t TN0 = TNLengths{}[I0];
    static constexpr index_t TN1 = TNLengths{}[I1];

    static_assert(TM0 == 1 && TN0 == 1);

    static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
                  (!ShareA && TN1 % dpp8::lane_group_size == 0));
    static constexpr index_t shared_elems_per_lane =
        ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;

    __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
    {
        static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                          BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                          CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
                      "wrong!");
    }

    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
                is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
            "wrong! inconsistent type");

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, TK0, 1>{}([&](auto tk0) {
            static_for<0, TM1, 1>{}([&](auto tm1) {
                static_for<0, TN1, 1>{}([&](auto tn1) {
                    vector_type<FloatA, TK1> a_vec;
                    vector_type<FloatB, TK1> b_vec;

                    static_for<0, TK1, 1>{}([&](auto tk1) {
                        constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
                        constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
                            a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));

                        constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
                        constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
                            b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));

                        a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
                        b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
                    });

                    using a_vector_t = typename vector_type<FloatA, TK1>::type;
                    using b_vector_t = typename vector_type<FloatB, TK1>::type;

                    constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(0, tm1, 0, tn1));

                    constexpr int src_lane =
                        ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
                               : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;

                    dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
                        a_vec.template AsType<a_vector_t>()[I0],
                        b_vec.template AsType<b_vector_t>()[I0],
                        c_buf(Number<c_offset>{}));
                });
            });
        });
    }
};
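
// Illustrative sketch (hypothetical constants, not part of the original file): with
// ShareA == true, TM1 == 8 and dpp8::lane_group_size == 8, each lane keeps only
// TM1 / lane_group_size == 1 row of the A chunk in its own registers and reads the remaining
// rows from the other lanes of its lane group through the DPP8 modifier, mirroring
// shared_elems_per_lane and src_lane above.
namespace dpp8_sharing_sketch {
constexpr int hypothetical_lane_group_size = 8;
constexpr int hypothetical_tm1             = 8;
constexpr int hypothetical_shared_elems    = hypothetical_tm1 / hypothetical_lane_group_size;
static_assert(hypothetical_shared_elems == 1, "each lane keeps a single row of A");
} // namespace dpp8_sharing_sketch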

} // namespace ck
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp (new file, 322 lines)
@@ -0,0 +1,322 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"

namespace ck {

enum struct DppInstr
{
    dpp8_f16_16x16x2 = 0,
    dpp8_f16_8x32x2,
    dpp8_f16_32x8x2
};

/**
 * Structure representing DPP GEMM executed by a single wavefront.
 *
 * Each structure instantiation must contain the following fields:
 * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
 *   number of threads in a wavefront;
 * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
 *   it's 8 in case of DPP8;
 * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
 *   operation;
 * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
 *   operation;
 * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
 * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
 * - m_per_thread - size along M dimension of the tile calculated by a single thread;
 * - n_per_thread - size along N dimension of the tile calculated by a single thread;
 * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
 * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
 *
 * Not all combinations are supported yet; for the current restrictions, see the static asserts
 * in the DppSelector's constructor.
 */
template <DppInstr instr>
struct dpp_type;

template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 32;
    static constexpr index_t n_per_wave      = 8;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread,
                               n_per_thread,
                               k_per_dpp,
                               BaseType,
                               ADataType,
                               BDataType,
                               CDataType,
                               share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread,
                               n_per_thread,
                               k_per_dpp,
                               BaseType,
                               ADataType,
                               BDataType,
                               CDataType,
                               share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 16;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread,
                               n_per_thread,
                               k_per_dpp,
                               BaseType,
                               ADataType,
                               BDataType,
                               CDataType,
                               share_a>{}
            .Run(a, b, reg_c);
    }
};

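// Illustrative consistency sketch (not part of the original file): for the dpp8_f16_16x16x2
// entry the documented field relations hold, e.g. a wave's 32 / 8 = 4 lane groups each own an
// 8x8 piece of the 16x16 wave tile, and every lane accumulates an 8x1 column of its piece.
static_assert(dpp_type<DppInstr::dpp8_f16_16x16x2>::wave_size /
                      dpp_type<DppInstr::dpp8_f16_16x16x2>::lanegroup_size ==
                  4,
              "four lane groups per wave");
static_assert(dpp_type<DppInstr::dpp8_f16_16x16x2>::m_per_lanegroup *
                      dpp_type<DppInstr::dpp8_f16_16x16x2>::n_per_lanegroup * 4 ==
                  dpp_type<DppInstr::dpp8_f16_16x16x2>::m_per_wave *
                      dpp_type<DppInstr::dpp8_f16_16x16x2>::n_per_wave,
              "lane-group tiles exactly cover the wave tile");
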
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
    template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
    static constexpr auto GetDpp();

    template <>
    static constexpr auto GetDpp<half_t, 8, 32>()
    {
        return DppInstr::dpp8_f16_8x32x2;
    }

    template <>
    static constexpr auto GetDpp<half_t, 16, 16>()
    {
        return DppInstr::dpp8_f16_16x16x2;
    }

    template <>
    static constexpr auto GetDpp<half_t, 32, 8>()
    {
        return DppInstr::dpp8_f16_32x8x2;
    }

    static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};

    __host__ __device__ constexpr DppSelector()
    {
        static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
        static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);

        static_assert(selected_dpp.k_per_dpp % 2 == 0);

        static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
        constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
        constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
        constexpr index_t num_dpp_c_elems =
            selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
        static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
        static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);

        if constexpr(selected_dpp.share_a)
        {
            static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
            static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
            static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
                          selected_dpp.lanegroup_size);
        }
        else
        {
            static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
            static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
                          selected_dpp.lanegroup_size);
            static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
        }

        // The checks below come from restrictions of the current implementation; they could be
        // removed once the implementation is generalized.
        static_assert(selected_dpp.share_a);
        static_assert(selected_dpp.n_per_thread == 1);
        static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
        static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
        static_assert(selected_dpp.n_per_lanegroup ==
                      selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
    }

    static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};

template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __host__ __device__ constexpr DppGemm()
    {
        static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
                      "MPerDpp must be either 8, 16 or 32.");
        static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
                      "NPerDpp must be either 8, 16 or 32.");

        static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
    }

    __device__ static constexpr index_t GetRegSizePerDpp()
    {
        return MPerDpp * NPerDpp / dpp_instr.wave_size;
    }

    template <class ADataType, class BDataType, class CDataType>
    __device__ void
    Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
    {
        static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
                          is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
                          is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
                      "BaseType must be double, float, half, bfloat16, int8 or f8!");

        static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
            dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
        });
    }

    __device__ static auto GetLaneIdInWave()
    {
        return get_thread_local_1d_id() % dpp_instr.wave_size;
    }

    __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }

    __device__ static auto GetLaneIdInLaneGroup()
    {
        return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
    }

    __device__ static auto GetLaneGroupIdInWave()
    {
        return GetLaneIdInWave() / dpp_instr.lanegroup_size;
    }

    __device__ static auto GetDppOpIdx()
    {
        const auto lanegroupId = GetLaneGroupIdInWave();

        constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(
                make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
                                                dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
            make_multi_index(lanegroupId));

        const auto m_dpp_idx = dpp_idx[I0];
        const auto n_dpp_idx = dpp_idx[I1];

        return make_tuple(m_dpp_idx, n_dpp_idx);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
    {
        const auto laneId   = get_thread_local_1d_id();
        const auto wave_row = laneId / dpp_instr.n_per_wave;
        auto m_idx          = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
        return make_tuple(0, m_idx % dpp_instr.m_per_wave);
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
    {
        const auto laneId = get_thread_local_1d_id();
        return make_tuple(0, laneId % dpp_instr.n_per_wave);
    }

    __device__ static CIndex GetBeginOfThreadBlk()
    {
        const auto dpp_op_idx = GetDppOpIdx();

        const auto m_dpp_op_idx = dpp_op_idx[I0];
        const auto n_dpp_op_idx = dpp_op_idx[I1];

        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;

        return CIndex{m_offset, n_offset};
    }

    static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};

    static constexpr auto dpp_instr = dpp.selected_dpp;

    static constexpr auto K0PerDpp = 1;
    static constexpr auto K1PerDpp = dpp.GetK1PerDpp();

    __host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
    {
        return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
    }
};

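// Illustrative sketch (hypothetical constants, not part of the original file): for the 16x16
// f16 entry, GetBeginOfThreadBlk above places lane group g of a wave at row 8 * (g / 2) of the
// C wave tile and lane l of that group at column 8 * (g % 2) + l; the small mirror below only
// spells out that index arithmetic.
namespace dpp8_lanegroup_mapping_sketch {
constexpr int sketch_m_begin(int lanegroup_id) { return (lanegroup_id / 2) * 8; }
constexpr int sketch_n_begin(int lanegroup_id, int lane_in_group)
{
    return (lanegroup_id % 2) * 8 + lane_in_group;
}
static_assert(sketch_m_begin(3) == 8 && sketch_n_begin(3, 5) == 13,
              "lane 5 of lane group 3 starts at C(8, 13)");
} // namespace dpp8_lanegroup_mapping_sketch
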
} // namespace ck