diff --git a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt new file mode 100644 index 0000000000..d68f901fd0 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt @@ -0,0 +1,19 @@ +set(EXAMPLE_GEMM "basic_gemm") +# Not using add_example_executable() to add this target, since we don't want it +# to be included in "make all/install/check" (hence EXCLUDE_FROM_ALL below). +message("adding example ${EXAMPLE_GEMM}") + +add_executable(${EXAMPLE_GEMM} EXCLUDE_FROM_ALL gemm.cpp) +target_include_directories(${EXAMPLE_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) + +# NOTE: -Wno-undefined-func-template lets the sources compile without explicitly declared function specializations +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_GEMM} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global property, otherwise the progress messages generated +# by cmake list too many files and fail with "execvp: /bin/sh: Argument list too long". +# Caveat: RULE_MESSAGES is a GLOBAL property, so this setting is not limited to this target. +# TODO: consider generating the makefile ourselves +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp new file mode 100644 index 0000000000..a0fc8ed341 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg.hpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_default_policy.hpp" + +namespace ck_tile { + +// Block-level GEMM built from warp-level GEMMs. +// A is block window on shared memory +// C is block distributed tensor; B is block window on shared memory +template +struct BlockGemmASmemBSmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B (accumulates into the caller-provided c_block_tensor) + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindowTmp& a_block_window_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template get<1>(); + constexpr index_t NWarp = config.template get<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( 
a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array, MIterPerWarp> a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array, NIterPerWarp> b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // 
read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B (builds and returns a new C block tensor) + template + CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t{}))>; + + constexpr index_t MWarp = config.template get(number<1>{}); + constexpr index_t NWarp = config.template get(number<2>{}); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( 
a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM, a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array, MIterPerWarp> a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN, b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array, NIterPerWarp> b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto 
c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // C warp tensor for this (mIter, nIter) tile + CWarpTensor c_warp_tensor; + + // warp GEMM + if constexpr(kIter == 0) + { + // first K step: c = a * b (c_block_tensor holds no valid data yet) + c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor); + } + else + { + // later K steps: c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp new file mode 100644 index 0000000000..75bfddd634 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg: picks the warp-level GEMM and the two warp counts used by the block GEMM +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBSmemCRegDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); // (warp gemm, 4, 1): the block GEMM reads elements 1 and 2 as MWarp and NWarp + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); // same warp layout as the F16 branch + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); // NOTE(review): a non-dependent static_assert(false) in a discarded branch is only accepted by compilers implementing CWG2518 - confirm the toolchain + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 0000000000..36b7cf6215 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp" + +namespace ck_tile { + +// Block GEMM pipeline: global memory -> LDS -> block GEMM, with one K tile prefetched ahead. +// A Tile Window: global memory; B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS (placed at the start of p_smem) + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + 16) * 16; + + // B tile in LDS (placed after the A tile, whose size is rounded up to a multiple of 16 bytes) + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + 
a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = Policy::template GetBlockGemm(); + + // Acc register tile (only the return type of the C = A * B overload is used here) + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + store_tile(a_copy_lds_window, a_block_tile); + // global read 1 + a_block_tile = 
load_tile(a_copy_dram_window); + + // LDS write 0 + store_tile(b_copy_lds_window, b_block_tile); + // global read 1 + b_block_tile = load_tile(b_copy_dram_window); + } + + index_t iCounter = num_loop - 2; // hot-loop trip count; num_loop >= 2 is required because tile 1 is prefetched above + + while(iCounter > 0) + { + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + store_tile(a_copy_lds_window, a_block_tile); + // global read i + 2 + a_block_tile = load_tile(a_copy_dram_window); + + // LDS write i + 1 + store_tile(b_copy_lds_window, b_block_tile); + // global read i + 2 + b_block_tile = load_tile(b_copy_dram_window); + + iCounter--; + + } + + // tail: the last two GEMMs, no further global reads + { + block_sync_lds(); + + // GEMM num_loop - 2 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // LDS write num_loop - 1 + store_tile(a_copy_lds_window, a_block_tile); + + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp new file mode 100644 index 0000000000..c7bbf5d552 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmPipelineAGmemBGmemCReg: LDS layouts, DRAM load distributions and the block GEMM it uses +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy +{ + // 3d + padding: the outermost stride is (kMPerBlock + 1) * 8, i.e. one padded row; merged below into a 2-D [M, K] view + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + padding: same construction as MakeALdsBlockDescriptor, with kNPerBlock in place of kMPerBlock + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto 
MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks (M1 is the number of warps in the block) + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); // number of BDataType elements in 16 bytes + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks (N1 is the number of warps in the block) + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + + return BlockGemmASmemBSmemCReg{}; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp new file mode 100644 index 0000000000..0c9bbd21eb --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -0,0 +1,152 @@ +#include + +#include "ck_tile/host.hpp" +#include "reference_gemm.hpp" + +#include "gemm.hpp" + +/* + * 
Toy code of GEMM + * Assume simplest case. + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// elementwise lambda: identity, applied to every C element after conversion to CDataType +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) // usage: <exe> verification + { + verification = std::stoi(argv[1]); + } + if(argc == 5) // usage: <exe> verification M N K + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host verify + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + + constexpr ck_tile::index_t kGemmMPerBlock = 256; + constexpr 
ck_tile::index_t kGemmNPerBlock = 128; + constexpr ck_tile::index_t kGemmKPerBlock = 32; + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); // integer division: M and N are assumed to be multiples of the block tile + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true}, + ck_tile::make_kernel( + gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference gemm on the host, compared against the device result + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm(a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; // ave_time is printed as ms below, so flop / 1e9 / ms gives TFlops + + float gb_per_sec = num_btype / 1.E6 / ave_time; // bytes / 1e6 / ms gives GB/s + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} + diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp new file mode 100644 index 0000000000..0dc6945002 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "grid_gemm.hpp" + +namespace ck_tile { + +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// C = A * B, with A viewed as [M, K], B as [N, K] and C as [M, N] (see the tensor views built in operator()) +template +struct Gemm +{ + using GridGemmProblem = GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t NumTilesM, + index_t NumTilesN) + { + const auto unmerge = make_merge_transform(make_tuple(NumTilesN, NumTilesM)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); // swapped so the result is ordered (M tile, N tile), as GridGemm expects + + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + + return BlockGemmPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm = GridGemm; + + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const 
BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{}, number<1>{}); + }(); + + GridGemm{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp new file mode 100644 index 0000000000..468c47abe1 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck_tile { + +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem: each block id is mapped to one (M tile, N tile) pair + const auto id_block = get_block_id(); + + const auto num_tile_m = M / kMPerBlock; + const auto num_tile_n = N / kNPerBlock; + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + // Block GEMM pipeline + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; // shared-memory scratch, sized by the pipeline + + const auto acc_block_tile = block_gemm_pipeline(a_block_window, + b_block_window, + K / kKPerBlock, // num_loop: number of K tiles + p_smem_char); + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) 
{ return c_element_func(type_convert(acc)); }, + acc_block_tile); + + // store C at tile origin (iM, iN) + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp new file mode 100644 index 0000000000..f2d0368bcb --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/reference_gemm.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { // computes row m of C: c(m, n) = sum over k of a(m, k) * b(n, k) + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(v_acc); + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); // presumably runs f for every row index of C across hardware_concurrency() threads - confirm in ck_tile/host +} diff --git a/example/ck_tile/99_toy_example/02_gemm/stream_config.hpp b/example/ck_tile/99_toy_example/02_gemm/stream_config.hpp new file mode 100644 index 0000000000..505a602b24 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/stream_config.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +struct StreamConfig /* plain aggregate of launch options; NOTE(review): no file in this patch includes stream_config.hpp (gemm.cpp uses ck_tile::stream_config) - confirm it is needed */ +{ + hipStream_t stream_id_ = nullptr; /* HIP stream handle; defaults to nullptr */ + bool time_kernel_ = false; /* presumably enables kernel timing - confirm against the launcher */ + int log_level_ = 0; +}; diff --git a/example/ck_tile/99_toy_example/CMakeLists.txt b/example/ck_tile/99_toy_example/CMakeLists.txt index 80024d45e8..3a0bc9ad1a 100644 --- a/example/ck_tile/99_toy_example/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/CMakeLists.txt @@ -3,3 +3,4 @@ include_directories(AFTER ) add_subdirectory(01_add) +add_subdirectory(02_gemm)