mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[Navi3x-LWPCK-545] Block-wise GEMM + Real GEMM_WMMA_FP16 (#541)
* wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size
This commit is contained in:
801
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
Normal file
801
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
Normal file
@@ -0,0 +1,801 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
#define CK_MNK_LOOP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerWMMA,
|
||||
index_t NPerWMMA,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
/* A: K0PerBlock x MPerBlock x K1
|
||||
* B: K0PerBlock x NPerBlock x K1
|
||||
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
* KPACK == WMMA_K = 16
|
||||
*/
|
||||
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto WmmaK = Number<16>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
|
||||
static constexpr index_t WaveSize = 32;
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
wmma_gemm.GetRegSizePerWmma(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
|
||||
// |KRepeat |MRepeat|MWave |MLane |KPack
|
||||
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
|
||||
// |KRepeat |NRepeat|Nwave |NLane |KPack
|
||||
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
|
||||
|
||||
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
|
||||
{
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWMMA * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// Thread level, register decriptor. Vector-write
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
|
||||
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
|
||||
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{},
|
||||
I1,
|
||||
MSubGroup,
|
||||
Number<NRepeat>{},
|
||||
I1,
|
||||
NThreadPerSubGroup,
|
||||
MAccVgprs));
|
||||
}
|
||||
|
||||
// Provide dimension size
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<MPerWMMA>{},
|
||||
Number<NRepeat>{},
|
||||
Number<NWaves>{},
|
||||
Number<NPerWMMA>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<A_K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<B_K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
|
||||
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
|
||||
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[K0, M0, M1, M2, K1]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
|
||||
|
||||
// B[K0, N0, N1, N2, K1]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
|
||||
|
||||
// C[M, N, NumRegWMMA]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
// block wise level pipe designed for inline asm
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerWMMA,
|
||||
index_t NPerWMMA,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
/* A: K0PerBlock x MPerBlock x K1
|
||||
* B: K0PerBlock x NPerBlock x K1
|
||||
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
|
||||
* KPACK == WMMA_K = 16
|
||||
*/
|
||||
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto WmmaK = Number<16>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
|
||||
static constexpr index_t WaveSize = 32;
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t KPerBlock =
|
||||
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
FloatAcc,
|
||||
MRepeat * NRepeat,
|
||||
wmma_gemm.GetRegSizePerWmma(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
|
||||
// |KRepeat |MRepeat|MWave |MLane |KPack
|
||||
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
|
||||
// |KRepeat |NRepeat|Nwave |NLane |KPack
|
||||
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
|
||||
|
||||
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO()
|
||||
{
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWMMA * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
// Thread level, register decriptor. Vector-write
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
|
||||
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
|
||||
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{},
|
||||
I1,
|
||||
MSubGroup,
|
||||
Number<NRepeat>{},
|
||||
I1,
|
||||
NThreadPerSubGroup,
|
||||
MAccVgprs));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
|
||||
transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
// Provide dimension size
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<MPerWMMA>{},
|
||||
Number<NRepeat>{},
|
||||
Number<NWaves>{},
|
||||
Number<NPerWMMA>{}));
|
||||
|
||||
return wmma_gemm
|
||||
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
AK0MK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<A_K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
BK0NK1BlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<B_K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
|
||||
}
|
||||
|
||||
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
|
||||
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
|
||||
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto RepeatDiff = MRepeat - NRepeat;
|
||||
// Read all Mrepeat, Nrepeat
|
||||
static_for<0, NRepeat, 1>{}([&](auto iN) {
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, Number<iN>{}, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, Number<iN>{}, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto iM) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, Number<iM>{}, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, Number<iM>{}, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
|
||||
static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto iN) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
if constexpr(KPerBlock > WmmaK)
|
||||
{
|
||||
// Read Consumed Next inner loop A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<WmmaK / A_K1>{}, Number<iCut>{}, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
}
|
||||
});
|
||||
|
||||
static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK) {
|
||||
// Stage 2: Run FIFO fashion loopover in Square
|
||||
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
|
||||
// Row Repeatation
|
||||
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
|
||||
make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
|
||||
// Read Consumed Next inner loop A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(
|
||||
Number<iWmmaK / A_K1>{}, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// Col Repeatation
|
||||
static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
// Read Consumed Next inner loop B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
|
||||
// Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
|
||||
static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto iN) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
if constexpr(KPerBlock > WmmaK)
|
||||
{
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number<iCut>{}, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Stage 2: Run FIFO fashion loopover in Square
|
||||
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
|
||||
// Row Repeatation
|
||||
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
|
||||
// Col Repeatation
|
||||
static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto iK) {
|
||||
a_thread_vec.template AsType<FloatA>()(iK) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(iK) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
|
||||
});
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
|
||||
// s_nop();
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
// s_nop();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, K0 = WmmaK]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
|
||||
|
||||
// B[N0, N1, N2, K0 = WmmaK]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
|
||||
|
||||
// C[M, N, NumRegWMMA]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
4,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
571
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
Normal file
571
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
Normal file
@@ -0,0 +1,571 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
|
||||
{
|
||||
assert(K % K1 == 0);
|
||||
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
|
||||
}
|
||||
#ifdef ENABLE_COLMAJOR
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
|
||||
}
|
||||
#endif
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_right_pad_transform(M, PadM)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
|
||||
{
|
||||
assert(K % K1 == 0);
|
||||
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
const auto b_grid_desc_k_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_k_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_m_n = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
|
||||
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t M01,
|
||||
index_t N01,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
a_grid_desc_k0_m_k1_ =
|
||||
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
|
||||
b_grid_desc_k0_n_k1_ =
|
||||
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
|
||||
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceGemmWmma_CShuffle::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>; // Last Option is W/O
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx1100")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
1,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<LoopScheduler, std::string> LoopSchedToString{
|
||||
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
|
||||
{PipelineVersion::v2, "v2"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmWmma_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< K1 << ", "
|
||||
<< MPerWMMA << ", "
|
||||
<< NPerWMMA << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat
|
||||
<< ">"
|
||||
<< " NumPrefetch: "
|
||||
<< NumPrefetch << ", "
|
||||
<< "LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
<< "PipelineVersion: "
|
||||
<< PipelineVersionToString[PipelineVer];
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
641
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
Normal file
641
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
Normal file
@@ -0,0 +1,641 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_wmma(
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
// const
|
||||
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
|
||||
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_grid_desc_k0_m_k1;
|
||||
ignore = b_grid_desc_k0_n_k1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx1100__))
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatAcc,
|
||||
typename FloatCShuffle,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M_N,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
return a_block_desc_k0perblock_mperblock_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
}
|
||||
}();
|
||||
|
||||
return b_block_desc_k0perblock_nperblock_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
|
||||
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 =
|
||||
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size_aligned * sizeof(FloatA) +
|
||||
b_block_space_size_aligned * sizeof(FloatB));
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
|
||||
(NPerBlock % (NRepeat * NPerWmma)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
return false;
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = K0 / K0PerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
|
||||
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
|
||||
{
|
||||
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
void* __restrict__ p_shared,
|
||||
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
// clang-format off
|
||||
/*******************************************************************************/
|
||||
// Memory buffer zone.
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
/*******************************************************************************/
|
||||
// BlockIdx.x -> [BlockId.m, BlockId.n]
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{ return; }
|
||||
|
||||
// Store BlockId into SGPR
|
||||
const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
/*******************************************************************************/
|
||||
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
constexpr auto max_lds_align = K1;
|
||||
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
|
||||
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
|
||||
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
|
||||
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
|
||||
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
|
||||
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
|
||||
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
|
||||
/* typename SrcData, */ FloatA,
|
||||
/* typename DstData, */ FloatA,
|
||||
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
|
||||
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
|
||||
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
|
||||
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
|
||||
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
|
||||
/* index_t DstVectorDim, */ 2,
|
||||
/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
|
||||
/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
|
||||
/* index_t SrcScalarStrideInVector, */ 1,
|
||||
/* index_t DstScalarStrideInVector, */ 1,
|
||||
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
|
||||
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
|
||||
a_grid_desc_k0_m_k1,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_element_op,
|
||||
a_block_desc_k0perblock_mperblock_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
|
||||
BElementwiseOperation,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_grid_desc_k0_n_k1),
|
||||
decltype(b_block_desc_k0perblock_nperblock_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_grid_desc_k0_n_k1,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_element_op,
|
||||
b_block_desc_k0perblock_nperblock_k1,
|
||||
make_multi_index(0, 0, 0),
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
|
||||
/*******************************************************************************/
|
||||
// GEMM
|
||||
constexpr auto WmmaK = 16;
|
||||
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_k0perblock_mperblock_k1),
|
||||
decltype(b_block_desc_k0perblock_nperblock_k1),
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
|
||||
// Prepare Register for C matrix
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
/*******************************************************************************/
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
|
||||
|
||||
// Shift Per SUB_K
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
|
||||
|
||||
// gridwise GEMM pipeline
|
||||
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
|
||||
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
|
||||
a_block_desc_k0perblock_mperblock_k1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_k0_n_k1,
|
||||
b_block_desc_k0perblock_nperblock_k1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
blockwise_gemm,
|
||||
c_thread_buf,
|
||||
K0BlockMainLoop);
|
||||
/*******************************************************************************/
|
||||
// write out to C, implement shuffle
|
||||
{
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
// This API Provide All dimension (size) you need
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
|
||||
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
|
||||
|
||||
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
|
||||
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
|
||||
constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
|
||||
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
|
||||
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
|
||||
|
||||
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
|
||||
|
||||
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<FloatCShuffle*>(p_shared),
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
|
||||
MWave, // MWave
|
||||
MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
|
||||
MAccVgprs)),
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
|
||||
NWave, // NWave
|
||||
NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatCShuffle,
|
||||
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
1, // vector write pixel
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
make_multi_index(0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
0,
|
||||
n_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
FloatCShuffle, // typename SrcData,
|
||||
FloatC, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
Sequence<CShuffleMRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
1,
|
||||
1,
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
507
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
Normal file
507
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
Normal file
@@ -0,0 +1,507 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/math.hpp"
|
||||
#include "ck/utility/amd_wmma.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct WmmaInstr
|
||||
{
|
||||
wmma_f32_16x16x16_f16 = 0,
|
||||
wmma_f32_16x16x16_bf16,
|
||||
wmma_f16_16x16x16_f16,
|
||||
wmma_bf16_16x16x16_bf16,
|
||||
wmma_i32_16x16x16_iu8,
|
||||
wmma_i32_16x16x16_iu4
|
||||
};
|
||||
|
||||
/*
|
||||
* WMMA Wave Tile Always MxNxK = 16x16x16
|
||||
* WAVE32
|
||||
-----------------------------------
|
||||
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|
||||
|RC1| | | | | | | | | | | | | | | |
|
||||
|RC2| | | | | | | | | | | | | | | |
|
||||
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|
||||
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|
||||
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|
||||
|RC6| | | | | | | | | | | | | | | |
|
||||
|RC7| | | | | | | | | | | | | | | |
|
||||
-----------------------------------
|
||||
| | | | | | | | | | | | | | | | | SubGroup 1
|
||||
| | | | | | | | | | | | | | | | |
|
||||
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|
||||
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
|
||||
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
|
||||
| | | | | | | | | | | | | | | | |
|
||||
| | | | | | | | | | | | | | | | |
|
||||
| | | | | | | | | | | | | | | | |
|
||||
-----------------------------------
|
||||
|
||||
|
||||
* WAVE64
|
||||
-----------------------------------
|
||||
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|
||||
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|
||||
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|
||||
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|
||||
-----------------------------------
|
||||
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
|
||||
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
|
||||
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
|
||||
| | | | | | | | | | | | | | | | |
|
||||
-----------------------------------
|
||||
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
|
||||
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
|
||||
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
|
||||
| | | | | | | | | | | | | | | | |
|
||||
-----------------------------------
|
||||
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
|
||||
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
|
||||
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
|
||||
| | | | | | | | | | | | | | | | |
|
||||
-----------------------------------
|
||||
|
||||
* RC = Register for storing accumalted result
|
||||
* T = Thread ID
|
||||
*/
|
||||
|
||||
template <WmmaInstr Instr, index_t WaveSize, typename = void>
|
||||
struct wmma_type
|
||||
{
|
||||
};
|
||||
|
||||
// A-swizzled
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
// * Data Pixel
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
// * num_acc_vgprs_per_wave alone M direction
|
||||
// * num_subgroups alone M direction
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t Opsel,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t Opsel,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
bool neg_a,
|
||||
bool neg_b,
|
||||
bool clamp,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
|
||||
a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_type_a,
|
||||
typename src_type_b,
|
||||
typename dst_type,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma>
|
||||
struct WmmaSelector
|
||||
{
|
||||
template <typename src_type_a_,
|
||||
typename src_type_b_,
|
||||
typename dst_type_,
|
||||
index_t MPerWmma_,
|
||||
index_t NPerWmma_>
|
||||
static constexpr auto GetWmma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f16_16x16x16_f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_bf16_16x16x16_bf16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8;
|
||||
}
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
static constexpr auto GetWmma<int4_t, int, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu4;
|
||||
}
|
||||
#endif
|
||||
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
|
||||
static constexpr auto selected_wmma =
|
||||
wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
|
||||
|
||||
__host__ __device__ constexpr WmmaSelector()
|
||||
{
|
||||
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
|
||||
|
||||
static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
|
||||
|
||||
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
|
||||
|
||||
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
|
||||
selected_wmma.acc_data_size ==
|
||||
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
|
||||
"WRONG! Invalid Number of Accumulator Register");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_type_a,
|
||||
typename src_type_b,
|
||||
typename dst_type,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct WmmaGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
using CIndex4D = MultiIndex<4>;
|
||||
|
||||
__host__ __device__ constexpr WmmaGemm()
|
||||
{
|
||||
static_assert(NPerWmma == 16 && MPerWmma == 16,
|
||||
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
|
||||
|
||||
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
|
||||
}
|
||||
|
||||
// WMMA output supporting C = A * B
|
||||
// Vector Write
|
||||
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
|
||||
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
|
||||
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
|
||||
{
|
||||
const auto MBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
|
||||
const auto NBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
|
||||
const auto MWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
|
||||
const auto NWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
|
||||
make_tuple(
|
||||
make_pass_through_transform(MBlockxRepeat),
|
||||
make_pass_through_transform(MWave),
|
||||
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
|
||||
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
|
||||
make_pass_through_transform(NBlockxRepeat),
|
||||
make_pass_through_transform(NWave),
|
||||
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2, 6>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}));
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegSizePerWmma()
|
||||
{
|
||||
return wmma_instr.num_acc_vgprs_per_wave;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(
|
||||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
|
||||
is_same<dst_type, float>::value) ||
|
||||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
|
||||
is_same<dst_type, float>::value) ||
|
||||
(is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
|
||||
is_same<dst_type, half_t>::value) ||
|
||||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
|
||||
is_same<dst_type, bhalf_t>::value) ||
|
||||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
|
||||
is_same<dst_type, int32_t>::value)
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
|
||||
is_same<dst_type, int32_t>::value)
|
||||
#endif
|
||||
,
|
||||
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
|
||||
"(int8, int32) or (int4, int32)!");
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
|
||||
|
||||
__device__ static auto GetSubGroupId()
|
||||
{
|
||||
return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
|
||||
}
|
||||
|
||||
__device__ static auto GetLaneIdUnderSubGroup()
|
||||
{
|
||||
return GetLaneId() % wmma_instr.num_thread_per_subgroups;
|
||||
}
|
||||
__device__ static auto GetSwizzledLaneIdLow()
|
||||
{
|
||||
return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
return GetSwizzledLaneIdLow();
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
return GetLaneIdUnderSubGroup();
|
||||
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk()
|
||||
{
|
||||
index_t n_offset = GetLaneIdUnderSubGroup();
|
||||
index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
|
||||
|
||||
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
static constexpr auto wmma =
|
||||
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
|
||||
static constexpr auto wmma_instr = wmma.selected_wmma;
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
|
||||
{
|
||||
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user