diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt
index c403e51ed9..9b9e100edf 100644
--- a/example/01_gemm/CMakeLists.txt
+++ b/example/01_gemm/CMakeLists.txt
@@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
+
+add_custom_target(example_gemm_wmma)
+add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
+
diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp
new file mode 100644
index 0000000000..48bcca257a
--- /dev/null
+++ b/example/01_gemm/gemm_wmma_fp16.cpp
@@ -0,0 +1,38 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "common.hpp"
+
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
+
+using ADataType = ck::half_t;
+using BDataType = ck::half_t;
+using AccDataType = float;
+using CShuffleDataType = float;
+using CDataType = ck::half_t;
+
+using ALayout = Row;
+using BLayout = Col;
+using CLayout = Row;
+
+using AElementOp = PassThrough;
+using BElementOp = PassThrough;
+using CElementOp = PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+
+// clang-format off
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
+// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
+// clang-format on
+
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+
+#include "run_gemm_example.inc"
+
+int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
new file mode 100644
index 0000000000..d75f37d7b3
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -0,0 +1,801 @@
+// SPDX-License-Identifier:
MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +#define CK_MNK_LOOP + +namespace ck { + +template +/* A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + */ +struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + // |KRepeat |MRepeat|MWave |MLane |KPack + return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + // |KRepeat |NRepeat|Nwave |NLane |KPack + return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = 
make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + + template + 
__device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[K0, M0, M1, M2, K1] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, Number{})); + + // B[K0, N0, N1, N2, K1] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, Number{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +// block wise level pipe designed for inline asm +template +/* A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + */ +struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
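// For illustration: with the configuration instantiated by the new gemm_wmma_fp16.cpp example
// above (BlockSize = 256, MPerBlock = 128, NPerBlock = 256, MPerWMMA = NPerWMMA = 16,
// MRepeat = NRepeat = 4), the wave decomposition below evaluates to
//     MWaves = MPerBlock / (MRepeat * MPerWMMA) = 128 / (4 * 16) = 2
//     NWaves = NPerBlock / (NRepeat * NPerWMMA) = 256 / (4 * 16) = 4
// so MWaves * NWaves * WaveSize = 2 * 4 * 32 = 256 threads, which matches BlockSize and is
// exactly what the constructor's static_assert checks.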
+ static constexpr index_t WaveSize = 32; + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + // |KRepeat |MRepeat|MWave |MLane |KPack + return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + // |KRepeat |NRepeat|Nwave |NLane |KPack + return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; + constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, + I1, + MSubGroup, + Number{}, + I1, + NThreadPerSubGroup, + MAccVgprs)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); + static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto 
a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr auto RepeatDiff = MRepeat - NRepeat; + // Read all Mrepeat, Nrepeat + static_for<0, NRepeat, 1>{}([&](auto iN) { + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + + static_for<0, MRepeat, 1>{}([&](auto iM) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + }); + + // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + static_for<0, RepeatDiff, 1>{}([&](auto iCut) { + static_for<0, NRepeat, 1>{}([&](auto iN) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + if constexpr(KPerBlock > WmmaK) + { + // Read Consumed Next inner loop A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + } + }); + + static_for{}([&](auto iWmmaK) { + // Stage 2: Run FIFO fashion loopover in Square + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { + // Row Repeatation + static_for{}([&](auto iN) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + + // Read Consumed Next inner loop A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple( + Number{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + + // Col Repeatation + static_for{}([&](auto iM) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + 
c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + // Read Consumed Next inner loop B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, Number{}, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + b_thread_buf); + }); + + // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat + static_for<0, RepeatDiff, 1>{}([&](auto iCut) { + static_for<0, NRepeat, 1>{}([&](auto iN) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + if constexpr(KPerBlock > WmmaK) + { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number{}, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, Number{}, I0, I0, I0), + a_thread_buf); + } + }); + }); + + // Stage 2: Run FIFO fashion loopover in Square + static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { + // Row Repeatation + static_for{}([&](auto iN) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + + // Col Repeatation + static_for{}([&](auto iM) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto iK) { + a_thread_vec.template AsType()(iK) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(iK) = + b_thread_buf[Number{}]; + }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); + // s_nop(); + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + // s_nop(); + }); + }); + } + + protected: + // A[M0, M1, M2, K0 = WmmaK] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, K0 = WmmaK] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, Number{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + + using 
BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp new file mode 100644 index 0000000000..dbcceac68f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -0,0 +1,571 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmWmma_CShuffle : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } +#ifdef ENABLE_COLMAJOR + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } +#endif + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + 
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma< + BlockSize, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + 
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmWmma_CShuffle::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", " + << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; // Last Option is W/O + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + 
StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmWmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerWMMA << ", " + << NPerWMMA << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp new file mode 100644 index 0000000000..d70c5180da --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -0,0 +1,641 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_wmma( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + // const + // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup + // c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + 
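// The #if guard above compiles the real kernel body only for host passes and for gfx1100
// device compilation; on any other device target the kernel degenerates to these `ignore`
// assignments, which merely consume the arguments so the compiler does not warn about
// unused parameters.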
ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0perblock_mperblock_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0perblock_nperblock_k1; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0perblock_mperblock_k1 = + GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0perblock_nperblock_k1 = + GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned * sizeof(FloatA) + + b_block_space_size_aligned * sizeof(FloatB)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const 
AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
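// Overview of Run(): the grid pointers are wrapped in dynamic buffers, the 1-D block id is
// mapped to an (m, n) C tile, A and B tiles are staged in LDS by the two blockwise copies,
// the gridwise GEMM pipeline iterates the blockwise WMMA GEMM over K0, and finally the
// accumulator registers are shuffled through LDS (CShuffle) and written back to global memory.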
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + constexpr auto max_lds_align = K1; + constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ FloatA, +/* typename DstData, */ FloatA, +/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), +/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0perblock_mperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + FloatB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0perblock_nperblock_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + 
b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0perblock_nperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + +/*******************************************************************************/ + // GEMM + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); + + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0perblock_mperblock_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0perblock_nperblock_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto 
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool 
ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp new file mode 100644 index 0000000000..0672bf8e5b --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
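+
+// This header supplies the warp-level WMMA building blocks used by the blockwise GEMM above:
+// * the wmma_type traits describe one 16x16x16 wave tile per instruction / wave-size
+//   combination (VGPR counts, sub-group layout) and dispatch to the wave32 or wave64
+//   intrinsic wrapper;
+// * WmmaSelector picks the WmmaInstr from the (src, dst) data-type pair (currently
+//   instantiated with a hardcoded wave size of 32 as a workaround);
+// * WmmaGemm exposes the per-lane A/B/C origin helpers, the C output descriptor helper,
+//   and Run(), which issues one WMMA per call and swaps the A/B operands when TransposeC
+//   is set.
+//
+// Illustrative use from a single thread's perspective (argument order is indicative only;
+// the real template arguments are supplied by the blockwise GEMM):
+//
+//   auto wmma = WmmaGemm</* data types, MPerWmma=16, NPerWmma=16, KPack, TransposeC */>{};
+//   wmma.Run(a_lane_vec, b_lane_vec, c_lane_acc);  // one 16x16x16 wave tile: C += A * B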
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_wmma.hpp" + +namespace ck { + +enum struct WmmaInstr +{ + wmma_f32_16x16x16_f16 = 0, + wmma_f32_16x16x16_bf16, + wmma_f16_16x16x16_f16, + wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu8, + wmma_i32_16x16x16_iu4 +}; + +/* + * WMMA Wave Tile Always MxNxK = 16x16x16 + * WAVE32 + ----------------------------------- + |RC0| | | | | | | | | | | | | | | | SubGroup 0 + |RC1| | | | | | | | | | | | | | | | + |RC2| | | | | | | | | | | | | | | | + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC6| | | | | | | | | | | | | | | | + |RC7| | | | | | | | | | | | | | | | + ----------------------------------- + | | | | | | | | | | | | | | | | | SubGroup 1 + | | | | | | | | | | | | | | | | | + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + ----------------------------------- + + + * WAVE64 + ----------------------------------- + |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0 + |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1 + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2 + | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4| + | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3 + | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6| + | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3| + | | | | | | | | | | | | | | | | | + ----------------------------------- + +* RC = Register for storing accumalted result +* T = Thread ID +*/ + +template +struct wmma_type +{ +}; + +// A-swizzled +template +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed in Navi3x, Will be wave mode dependent on Navi4x + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f32_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute 
fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f32_16x16x16_bf16_w64::Run(a, b, reg_c); + } + } +}; + +#ifdef CK_UNPACKED_ACC_DESC_LOGIC +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); + } + } +}; + +#endif + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t 
m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32::Run( + a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_i32_16x16x16_iu8_w64::Run( + a, b, reg_c); + } + } +}; + +template +struct WmmaSelector +{ + template + static constexpr auto GetWmma(); + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f16; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf16; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_f16_16x16x16_f16; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_bf16_16x16x16_bf16; + } + + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_i32_16x16x16_iu8; + } +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + static constexpr auto GetWmma() + { + return WmmaInstr::wmma_i32_16x16x16_iu4; + } +#endif + // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround + static constexpr auto selected_wmma = + wmma_type(), Number<32>{}>{}; + + __host__ __device__ constexpr WmmaSelector() + { + static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16"); + + static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16"); + + static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16"); + + static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * + selected_wmma.acc_data_size == + selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, + "WRONG! 
Invalid Number of Accumulator Register"); + } +}; + +template +struct WmmaGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr WmmaGemm() + { + static_assert(NPerWmma == 16 && MPerWmma == 16, + "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); + + static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); + } + + // WMMA output supporting C = A * B + // Vector Write + // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 6>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + } + + __device__ static constexpr index_t GetRegSizePerWmma() + { + return wmma_instr.num_acc_vgprs_per_wave; + } + + __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert( + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || (is_same::value && is_same::value && + is_same::value) +#endif + , + "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " + "(int8, int32) or (int4, int32)!"); + if constexpr(!TransposeC) + { + wmma_instr.template run(p_a_wave, p_b_wave, p_c_thread); + } + else + { + wmma_instr.template run(p_b_wave, p_a_wave, p_c_thread); + } + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } + + __device__ static auto GetSubGroupId() + { + return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; + } + + __device__ static auto GetLaneIdUnderSubGroup() + { + return GetLaneId() % wmma_instr.num_thread_per_subgroups; + } + __device__ static 
auto GetSwizzledLaneIdLow() + { + return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + return GetSwizzledLaneIdLow(); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + return GetLaneIdUnderSubGroup(); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave; + + return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; + } + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto wmma_instr = wmma.selected_wmma; + + __host__ __device__ static constexpr auto + GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() + { + return make_tuple(I1, I1, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 82bf2a5eb5..4fc0be1fbd 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } +// Ranged input operand +__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) +{ + asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} + } // namespace ck #endif diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 752876a769..a0e79220e0 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -4,11 +4,13 @@ #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP +#include "ck/utility/amd_inline_asm.hpp" #include "data_type.hpp" // TODO: Add arch limitation namespace ck { -// wave32 only +/********************************WAVE32 MODE***********************************************/ + // src: fp16, dst: fp32 template struct intrin_wmma_f32_16x16x16_f16_w32; @@ -19,8 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. 
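+        // In the asm form, "%0" is both the destination ("=v"(c)) and, through the matching
+        // constraint "0"(c), the accumulator input, so the instruction updates C in place
+        // (D = A * B + D). The builtin-based version it replaces is kept below, commented
+        // out, for reference.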
+ amd_assembly_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()(Number<0>{})); + // reg_c.template AsType()(Number<0>{}) = + // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template + // AsType()[Number<0>{}]); } }; @@ -98,5 +105,95 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> } }; +/********************************WAVE64 MODE***********************************************/ + +template +struct intrin_wmma_f32_16x16x16_f16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); + } +}; + +// src: fp16, dst: fp16 +template +struct intrin_wmma_f16_16x16x16_f16_w64; + +template +struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: bf16, dst: bf16 +template +struct intrin_wmma_bf16_16x16x16_bf16_w64; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w64; + +template +struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); + } +}; + } // namespace ck #endif diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index ef3f831abd..c70e6a407d 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -97,6 +97,7 @@ builtin_wmma_naive_selector __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) { + __shared__ src_t p_shared[16 * 16 * 2]; const int lIdx = threadIdx.x; // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the @@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) using src_vec = typename vector_type::type; src_vec a_frag = {}; src_vec b_frag = {}; + + src_vec a_temp = {}; + src_vec b_temp = {}; // initialize c fragment to 0 using acc_vec = StaticBufferTupleOfVector; 
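    // For the f32-accumulating cases in wave32, the 16x16 output tile (256 values) is spread
    // across 32 lanes, i.e. 8 accumulator values per lane (RC0..RC7 in the wave32 layout
    // sketch in wmma_gemm.hpp), which matches the 8-iteration static_for at the end of this
    // kernel.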
acc_vec c_thread_buf_; @@ -111,21 +115,57 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11 // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482 // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101 - const int lane = lIdx % 16; + const int lane = lIdx % 16; + const int lane_lo = lIdx / 2; + const int lane_hi = lIdx % 2; + for(int ele = 0; ele < 8; ++ele) + { + a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele]; + } + + for(int ele = 0; ele < 8; ++ele) + { + b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele]; + } + + __syncthreads(); + + for(int ele = 0; ele < 8; ++ele) + { + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele]; + } + + for(int ele = 0; ele < 8; ++ele) + { + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; + } + + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); for(int ele = 0; ele < 16; ++ele) { - b_frag[ele] = b[16 * lane + ele]; + b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16]; } // follow origin design for(int ele = 0; ele < 16; ++ele) { - a_frag[ele] = a[16 * lane + ele]; + a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; } + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); + // sync threads, similar to mma_sync - __syncthreads(); + // __syncthreads(); builtin_wmma_naive_selector(a_frag, b_frag, c_thread_buf_); + // since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16 + // when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, + // c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType()(Number<0>{})); __syncthreads(); // wait for results, similar to mma_sync static_for<0, 8, 1>{}([&](auto ele) {
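        // one iteration per lane-held accumulator value (8 per lane for a 16x16 f32 tile on wave32)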