From edf79c7064ce91f76e4dcc1a4efb3279eda2a1e4 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Tue, 5 Aug 2025 04:26:23 +0000 Subject: [PATCH] basic gemm softmax --- .../ck_tile/39_gemm_softmax/CMakeLists.txt | 27 + example/ck_tile/39_gemm_softmax/README.md | 58 +++ .../block_gemm_asmem_bsmem_creg.hpp | 372 ++++++++++++++ ...k_gemm_asmem_bsmem_creg_default_policy.hpp | 100 ++++ .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 464 ++++++++++++++++++ ...peline_agmem_bgmem_creg_default_policy.hpp | 352 +++++++++++++ example/ck_tile/39_gemm_softmax/config.h | 38 ++ example/ck_tile/39_gemm_softmax/gemm.hpp | 195 ++++++++ .../ck_tile/39_gemm_softmax/gemm_softmax.cpp | 202 ++++++++ example/ck_tile/39_gemm_softmax/grid_gemm.hpp | 78 +++ .../39_gemm_softmax/reference_gemm.hpp | 65 +++ .../ck_tile/39_gemm_softmax/stream_config.hpp | 14 + example/ck_tile/CMakeLists.txt | 1 + 13 files changed, 1966 insertions(+) create mode 100755 example/ck_tile/39_gemm_softmax/CMakeLists.txt create mode 100755 example/ck_tile/39_gemm_softmax/README.md create mode 100755 example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg.hpp create mode 100755 example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg_default_policy.hpp create mode 100755 example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg.hpp create mode 100755 example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp create mode 100755 example/ck_tile/39_gemm_softmax/config.h create mode 100755 example/ck_tile/39_gemm_softmax/gemm.hpp create mode 100755 example/ck_tile/39_gemm_softmax/gemm_softmax.cpp create mode 100755 example/ck_tile/39_gemm_softmax/grid_gemm.hpp create mode 100755 example/ck_tile/39_gemm_softmax/reference_gemm.hpp create mode 100755 example/ck_tile/39_gemm_softmax/stream_config.hpp mode change 100644 => 100755 example/ck_tile/CMakeLists.txt diff --git a/example/ck_tile/39_gemm_softmax/CMakeLists.txt b/example/ck_tile/39_gemm_softmax/CMakeLists.txt new file mode 100755 
index 0000000000..b250157523 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/CMakeLists.txt @@ -0,0 +1,27 @@ +set(EXAMPLE_REDUCE "tile_example_basic_gemm_softmax") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_REDUCE}") + +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL gemm_softmax.cpp) +target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_REDUCE_COMPILE_OPTIONS) + +# generate assembly +# list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +if(DEFINED kernel) + message("Compiling with Kernel: ${kernel}") + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1) +endif() + +target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/39_gemm_softmax/README.md b/example/ck_tile/39_gemm_softmax/README.md new file mode 100755 index 0000000000..3c03ffd402 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/README.md @@ -0,0 +1,58 @@ + + +# CK_TILE Toy Example + +This repository demonstrates a toy example implemented using ck_tile + +## Build Instructions + +Follow these steps to build the examples: + +```sh +cd composable_kernel +mkdir build +cd build + +cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + 
-Dkernel=N .. +``` + +### Compile Examples + +#### **GEMM Softmax Example** +```sh +make -j128 tile_example_basic_gemm_softmax +``` + +## Running Examples + +### **GEMM Softmax Example** +```sh +./bin/tile_example_basic_gemm_softmax 1 4096 256 7168 +``` + +## Advanced part +#### **GEMM Example** +##### Follow these steps to build and run the different kernels: +```sh + +cd composable_kernel +mkdir build +cd build + +# for naive kernel +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=N .. && make -j128 tile_example_basic_gemm_softmax + +# for kernel A +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=A .. && make -j128 tile_example_basic_gemm_softmax + +# for kernel B +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=B .. && make -j128 tile_example_basic_gemm_softmax + +... + +# for kernel H +cmake -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942" -Dkernel=H .. && make -j128 tile_example_basic_gemm_softmax \ No newline at end of file diff --git a/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg.hpp b/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg.hpp new file mode 100755 index 0000000000..5c58fa3d60 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg.hpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + +#if defined(ENABLE_PREFETCH) + // A block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM); + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / 
WarpGemm::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + // B block tile distribution for load from lds + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN); + constexpr index_t KPerBlock = BlockGemmShape::kK; + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile aWarpTile; + BLdsTile bWarpTile; + + // Prefetch from LDS to warp register + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + aWarpTile = load_tile(a_block_window); + bWarpTile = load_tile(b_block_window); + } +#endif + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v 
&& + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + 
b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; +#if defined(ENABLE_PREFETCH) +#pragma message("local data share prefetch") + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); +#else + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); +#endif + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); +#else + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); +#endif + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + 
c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + +#if !defined(ENABLE_PREFETCH) + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct 
B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + // Hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; +#if defined(ENABLE_PREFETCH) + a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); +#else + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); +#endif + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; +#if defined(ENABLE_PREFETCH) + b_warp_tensor.get_thread_buffer() = 
bWarpTile.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); +#else + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); +#endif + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + // Warp GEMM + if constexpr(KIterPerWarp == 0) + { + // c = a * b + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + // c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg_default_policy.hpp new file mode 100755 index 0000000000..2fdb63794f --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "config.h" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCReg +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmASmemBSmemCRegDefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { +#if defined(ADJUST_BLOCK_TILE_SHAPE) + constexpr index_t kMWarp = 2; + constexpr index_t kNWarp = 2; +#else + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; +#endif + +#if defined(NAIVE_IMPLEMENTATION) +#pragma message("mfma m32 n32 k8") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_32x32x_8x2) +#pragma message("mfma m32 n32 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_16x16x16) +#pragma message("mfma m16 n16 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } +#elif defined(USING_MFMA_16x16x_16x2) +#pragma message("mfma m16 n16 k32") + if constexpr(std::is_same_v && + 
std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp); + } +#endif + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg.hpp new file mode 100755 index 0000000000..d975befe2b --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,464 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmSoftmaxPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = float; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + using BlockGemm = remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr 
ck_tile::index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + +#if defined(ENABLE_INSTRUCTION_SCH) + static constexpr index_t kPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + + static constexpr index_t GetSmemPack() { return Policy::template GetSmemPack(); } + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemm::MWarp; + constexpr index_t WaveNumN = BlockGemm::NWarp; + + constexpr index_t AB_LDS_RW_Width = GetSmemPack(); + + constexpr index_t A_Buffer_Load_Inst_Num = + kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + constexpr index_t B_LDS_Write_Inst_Num = + kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width); + + constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock / + (kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is 
likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 + ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(CDataType) / sizeof(ADataType) > + // sizeof(CDataType) / + // sizeof(BDataType) + // ? 
sizeof(CDataType) / + // sizeof(ADataType) : sizeof(CDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + constexpr auto num_mfma_per_dswrite_a = + (num_mfma_per_issue - num_dswrite_per_issue_a * 2 >= 1) ? 2 : 1; + constexpr auto num_mfma_per_dswrite_b = + (num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_a, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_a * + num_dswrite_per_issue_a, + 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_mfma_per_dswrite_b * + num_dswrite_per_issue_b, + 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + 
__builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } +#endif + // Runs the block GEMM over num_loop K-steps of the A (kMPerBlock x kKPerBlock) and B (kNPerBlock x kKPerBlock) windows, staging tiles in the LDS buffer at p_smem, then returns exp(c - rowmax) / rowsum computed on the resulting block tile. + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); // window extents must equal the block tile shape + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; // A region size in bytes, rounded up to a multiple of 16 + + // B tile in LDS, placed right after the aligned A region + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile
window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store (same distribution as the DRAM load window) + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store (same distribution as the DRAM load window) + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + +#if defined(ENABLE_PREFETCH) + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())); +#else + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); +#endif + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile, typed as whatever block_gemm returns for these windows + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile =
decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); // both windows advance by kKPerBlock along dimension 1 per K-step + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + +#if defined(ENABLE_PREFETCH) +#pragma message("global prefetch") + // Prefetch + // Global read 0 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + + if(num_loop > 1) + { + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // LDS write 0 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } + + __builtin_amdgcn_sched_barrier(0); + + // Main body: executes num_loop - 2 times + if(num_loop > 2) + { + index_t iCounter = 0; + do + { + block_sync_lds(); + + // LDS write 1 + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + // Global read 2 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window,
a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // Prefetch from LDS to warp register in block gemm + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + +#if defined(ENABLE_INSTRUCTION_SCH) + HotLoopScheduler(); +#endif + + __builtin_amdgcn_sched_barrier(0); + + iCounter += 1; + } while(iCounter < (num_loop - 2)); + } + + // Tail: when num_loop > 1, first consume the tile already prefetched from LDS; then write the last tile held in registers to LDS and multiply it + if(num_loop > 1) + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + } + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); +#else + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + iCounter--; + } +#endif + // NOTE(review): the reductions below only cover this block tile, so each row is normalized over the tile's kNPerBlock columns, not over the full N - confirm this is intended + // apply softmax for c_block_tile + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // m_local = rowmax(c_block_tile) + auto m_local = block_tile_reduce( + c_block_tile, sequence<1>{}, f_max, std::numeric_limits::lowest()); + + block_tile_reduce_sync(m_local, f_max); // presumably merges the per-thread partial maxima across lanes - TODO confirm + + // p_compute{i, j} = exp(c_block_tile{i, j} - m_local{i}) + auto p_compute = + make_static_distributed_tensor(c_block_tile.get_tile_distribution()); + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + +
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = exp(c_block_tile[i_j_idx] - m_local[i_idx]); + }); + }); + + // rowsum for p_compute{i, j} + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, ComputeDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum); + + // softmax = p_compute{i, j} / rowsum_p + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + p_compute(i_j_idx) = p_compute[i_j_idx] / rowsum_p[i_idx]; + }); + }); + // CDramBlockWindowTmp c_dram_block_window_tmp = c_dram_block_window; + + // store_tile(c_dram_block_window_tmp, type_convert(p_compute)); + + return p_compute; // the softmax tile is returned rather than stored here + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp new file mode 100755 index 0000000000..30bdcd679f --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "block_gemm_asmem_bsmem_creg.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +#include "config.h" + +namespace ck_tile { + +// Default policy for BlockGemmPipelineAGmemBGmemCReg +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy +{ + // LDS descriptor of the A block tile; the layout variant is chosen by the config.h macros + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + +#if defined(NAIVE_IMPLEMENTATION) + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_K_FIRST) // outermost stride carries one extra kKPack of padding + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{}));
+ +#elif defined(PADDING_MN_FIRST) // NOTE(review): no KERNEL_x option in config.h defines this macro + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) + using ADataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); // clamped to at least 1; 32 * 4 is presumably banks x bytes per bank - TODO confirm + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), +
make_tuple(sequence<0>{}, sequence<1>{})); +#endif + return a_lds_block_desc; + } + + // LDS descriptor of the B block tile; mirrors MakeALdsBlockDescriptor branch for branch + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + +#if defined(NAIVE_IMPLEMENTATION) // same condition as the first branch of MakeALdsBlockDescriptor + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_K_FIRST) // outermost stride carries one extra kKPack of padding; NOTE(review): confirm GetStaticLdsSize is derived from this descriptor + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + +#elif defined(PADDING_MN_FIRST) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{},
sequence<1>{})); + +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) + using BDataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); // clamped to at least 1; 32 * 4 is presumably banks x bytes per bank - TODO confirm + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); +#endif + + return b_lds_block_desc; + } + // Builds the tile distribution for the A DRAM load from the factors K0, K1, M0, M1, M2 computed below. + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); // elements per 16 bytes + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + //
coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); // number of warps in the block + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // Builds the tile distribution for the B DRAM load from the factors K0, K1, N0, N1, N2 computed below. + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); // elements per 16 bytes + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); // number of warps in the block + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + +#if defined(ENABLE_INSTRUCTION_SCH) + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + // Returns the widest vector length (in elements) among 32 (only when PackedSize == 2), 16, 8, 4, 2 bytes scaled by PackedSize that divides both XPerTile and elements_per_thread; falls back to PackedSize. + template + CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Assume DataType is even!
+ if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && + PackedSize == 2) + { + return (PackedSize * 32 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) + { + return (PackedSize * 16 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) + { + return (PackedSize * 8 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 4 && + XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) + { + return (PackedSize * 4 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 2 && + XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) + { + return (PackedSize * 2 / sizeof(DataType)); + } + else + { + return PackedSize; + } + } + // Vector load size for A; delegates to GetGlobalVectorLoadSize. + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + { + using ADataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + return GetGlobalVectorLoadSize(); + } + // Vector load size for B; delegates to GetGlobalVectorLoadSize. + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + { + using BDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + return GetGlobalVectorLoadSize(); + } + // Forwards Problem::TransposeC. + template + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Problem::TransposeC; + } + // Returns the fixed pack size 8, the same value as kKPack in the LDS descriptors. + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPack() + { + constexpr index_t kKPack = 8; + return kKPack; + } +#endif + // Returns a BlockGemmASmemBSmemCReg instance. + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return
BlockGemmASmemBSmemCReg{}; + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/config.h b/example/ck_tile/39_gemm_softmax/config.h new file mode 100755 index 0000000000..235337df7b --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/config.h @@ -0,0 +1,38 @@ +// Maps the KERNEL_A .. KERNEL_H build selection to feature macros; with none defined, NAIVE_IMPLEMENTATION is used. +#if defined(KERNEL_A) +#define PADDING_K_FIRST +#define USING_MFMA_32x32x_8x2 +#elif defined(KERNEL_B) +#define PADDING_K_FIRST +#define USING_MFMA_16x16x16 +#elif defined(KERNEL_C) +#define PADDING_K_FIRST +#define USING_MFMA_16x16x_16x2 +#elif defined(KERNEL_D) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#elif defined(KERNEL_E) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#elif defined(KERNEL_F) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#elif defined(KERNEL_G) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH +#elif defined(KERNEL_H) +#define USING_MFMA_16x16x_16x2 +#define USING_XOR_BASED_BANK_CONFLICT_FREE +#define ADJUST_BLOCK_TILE_SHAPE +#define ENABLE_PREFETCH +#define ENABLE_INSTRUCTION_SCH +#define ENABLE_CACHE_AWARE_WG_SCH +#else +#define NAIVE_IMPLEMENTATION +#endif diff --git a/example/ck_tile/39_gemm_softmax/gemm.hpp b/example/ck_tile/39_gemm_softmax/gemm.hpp new file mode 100755 index 0000000000..0667eaba12 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/gemm.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "config.h" +#include "grid_gemm.hpp" + +namespace ck_tile { +// Type bundle consumed by GridGemm: element types plus the C element-wise functor. +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; +// Compile-time M/N/K extents of one tile. +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; +// Type bundle for the block pipeline: element types, tile shape and block size. +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// Kernel functor: wraps raw pointers as A [M, K], B [N, K], C [M, N] views with unit stride in the last dimension and runs GridGemm on them +template +struct Gemm +{ + using GridGemmProblem = + GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + // Returns a callable mapping a 1-D block id to an (m, n) tile index for an M0 x N0 tile grid. + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { +#if defined(ENABLE_CACHE_AWARE_WG_SCH) + return [=](index_t block_1d_id) { + constexpr index_t M01 = 4; + constexpr index_t GroupNum = 8; + + const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2; + const auto update_M0 = + ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum; + + const auto xcd_id = block_1d_id % GroupNum; + + const auto l_block_id = block_1d_id - (xcd_id % 2); + + const auto ridn = GroupNum * M01 * (update_N0 / 2); + const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn; + const auto lu = (l_block_id % GroupNum) + rid * ridn; + + const auto
sub_N0_id = (l_block_id - lu) / (GroupNum * M01); + const auto sub_M0_id = + (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum; + + auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2); + auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2); + + const auto total_update_size = update_N0 * update_M0; + + if(block_1d_id >= total_update_size) + { + auto x = (block_1d_id + 1) - total_update_size; + auto rlen = N0 - update_N0; + + auto rm = 0; + auto rn = 0; + if(rlen > 0) + { + rm = (x - 1) / rlen; + rn = x % rlen; + } + + if(rlen > 0 and rm < M0) + { + n = rn + update_N0; + m = rm; + } + else + { + x = x - rlen * M0; + rm = (x - 1) / update_N0; + rn = x % update_N0; + n = rn; + m = update_M0 + rm; + } + } + return make_multi_index(m, n); + }; +#else + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); // despite the name this is a merge transform; its lower index is used to split block_id + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; +#endif + } + // Instantiates the GEMM-softmax block pipeline for this policy's tile shape. + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + return BlockGemmSoftmaxPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm = GridGemm; + // Kernel entry: Lda, Ldb and Ldc are the row strides of A, B and C. + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{},
number<1>{}); + }(); + + GridGemm{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/gemm_softmax.cpp b/example/ck_tile/39_gemm_softmax/gemm_softmax.cpp new file mode 100755 index 0000000000..4215ad107c --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/gemm_softmax.cpp @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "config.h" +#include "ck_tile/host.hpp" +#include "gemm.hpp" +#include "reference_gemm.hpp" + +/* + * Toy code of GEMM + * Assume simplest case. + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// identity element-wise functor applied to the converted output values +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) // usage: exe verification, or exe verification M N K; any other argument count keeps the defaults + { + verification = std::stoi(argv[1]); + } + if(argc == 5) + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + +#if defined(KERNEL_A) + printf("*** Kernel A test *** \n"); + printf(" --> Using mfma_32x32x(8x2)\n"); +#elif defined(KERNEL_B) + printf("*** Kernel B test *** \n"); + printf(" --> Using mfma_16x16x16\n"); +#elif defined(KERNEL_C) + printf("*** Kernel C test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); +#elif defined(KERNEL_D) + printf("*** Kernel D test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); +#elif defined(KERNEL_E) + printf("*** Kernel E test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile
shape\n"); +#elif defined(KERNEL_F) + printf("*** Kernel F test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); +#elif defined(KERNEL_G) + printf("*** Kernel G test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); +#elif defined(KERNEL_H) + printf("*** Kernel H test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank-conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); + printf(" --> Enable cache-aware thread blocks schedule\n"); +#else + printf("*** Naive implementation test ***\n"); +#endif + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host tensors: a/b are the inputs, c_host_dev receives the device result + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr
ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + +#ifdef ADJUST_BLOCK_TILE_SHAPE + constexpr ck_tile::index_t kGemmMPerBlock = 128; + constexpr ck_tile::index_t kGemmKPerBlock = 64; +#else + constexpr ck_tile::index_t kGemmMPerBlock = 128; + constexpr ck_tile::index_t kGemmKPerBlock = 16; +#endif + constexpr ck_tile::index_t kGemmNPerBlock = 256; + + // The kernel has no tail handling: the grid size below and the K loop count in GridGemm both use + // truncating division, so partial tiles would be skipped silently. Reject such sizes up front. + if(M <= 0 || N <= 0 || K <= 0 || M % kGemmMPerBlock != 0 || N % kGemmNPerBlock != 0 || + K % kGemmKPerBlock != 0) + { + std::cerr << "M, N, K must be positive multiples of " << kGemmMPerBlock << ", " + << kGemmNPerBlock << ", " << kGemmKPerBlock << std::endl; + return 1; + } + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel( + gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference GEMM and softmax over full rows of length N; NOTE(review): the device pipeline normalizes per kGemmNPerBlock-wide tile, so results are expected to agree only when N == kGemmNPerBlock - confirm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm_softmax( + a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ?
"y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; // ave_time is printed as ms below, hence 1e9 for TFlops + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/example/ck_tile/39_gemm_softmax/grid_gemm.hpp b/example/ck_tile/39_gemm_softmax/grid_gemm.hpp new file mode 100755 index 0000000000..7946da263e --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/grid_gemm.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { +// Grid-level driver: each workgroup computes and stores one kMPerBlock x kNPerBlock tile of C. +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + using ComputeDataType = float; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem: one block per kMPerBlock x kNPerBlock output tile + const auto id_block = get_block_id(); + + const auto num_tile_m = integer_divide_ceil(M, kMPerBlock); + const auto num_tile_n = integer_divide_ceil(N, kNPerBlock); + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m,
num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) * kMPerBlock); + const auto iN = + __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; + + // C block window (store target) + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + // block_gemm_pipeline(a_block_window, b_block_window, c_window, K / kKPerBlock, p_smem_char, c_element_func); + + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); // loop count truncates, so K is assumed to be a multiple of kKPerBlock + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) { return c_element_func(type_convert(acc)); }, + acc_block_tile); + + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/39_gemm_softmax/reference_gemm.hpp b/example/ck_tile/39_gemm_softmax/reference_gemm.hpp new file mode 100755 index 0000000000..0c6bbecfc1 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/reference_gemm.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include <vector> + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +// Host reference for C = softmax(A * B^T) along each row of C; the GEMM row stays in AccDataType until the final store. +template +void reference_basic_gemm_softmax(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_n_k, + ck_tile::HostTensor& c_m_n) +{ + const int N = b_n_k.mDesc.get_lengths()[0]; + const int K = b_n_k.mDesc.get_lengths()[1]; + + auto f = [&](auto m) { + // Logits are kept unrounded: storing them in c_m_n first would round them to CDataType before exp(). + std::vector<AccDataType> acc_row(N); + + for(int n = 0; n < N; ++n) + { + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_n_k(n, k); + + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + acc_row[n] = v_acc; + } + // reference softmax + AccDataType v_max = std::numeric_limits::lowest(); + + // max + for(int n = 0; n < N; ++n) + { + v_max = v_max < acc_row[n] ? acc_row[n] : v_max; + } + + AccDataType v_exp_sum = 0; + // sum; each exponential is cached in acc_row for the normalization pass + for(int n = 0; n < N; ++n) + { + acc_row[n] = ck_tile::exp(acc_row[n] - v_max); + v_exp_sum += acc_row[n]; + } + + // elementwise + for(int n = 0; n < N; ++n) + { + c_m_n(m, n) = acc_row[n] / v_exp_sum; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])( + std::thread::hardware_concurrency()); +} diff --git a/example/ck_tile/39_gemm_softmax/stream_config.hpp b/example/ck_tile/39_gemm_softmax/stream_config.hpp new file mode 100755 index 0000000000..505a602b24 --- /dev/null +++ b/example/ck_tile/39_gemm_softmax/stream_config.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include +#include +// Aggregate of stream settings; every member carries an in-class default. +struct StreamConfig +{ + hipStream_t stream_id_{nullptr}; // stream handle, null unless the caller sets one + bool time_kernel_{false}; // presumably turns kernel timing on - TODO confirm + int log_level_{0}; // presumably a verbosity level - TODO confirm +}; diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt old mode 100644 new mode 100755 index 630b96ede0..479bebdbe3 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) +add_subdirectory(39_gemm_softmax) \ No newline at end of file