From 2312f4aaf93d4d31d8be205a3aa66dd962ea1bda Mon Sep 17 00:00:00 2001 From: aledudek Date: Tue, 11 Feb 2025 09:49:48 +0100 Subject: [PATCH 1/7] [CK_TILE] Add GetName for GEMM kernels (#1791) * [CK_TILE] Add GetName functions for Gemm Kernels * [CK_TILE] Add GetName for grouped gemm * [CK_TILE] Add GetName for gemm - review changes * [CK_TILE] Print also gemm problem pipeline and shape * [CK_TILE] Print also GemmPipelineScheduler * [CK_TILE] GetName - fixed Scheduler < float conversions * Add scaled conversions with tests * Add device conversions * Make sure all tests and examples are built for gfx950 * Facilitate testing of FP8 data types on the emulator * Introduce two new tensor generators * Enable instances built for gfx94 to be built on gfx950 * Verify 35_splitk_gemm on floating point numbers. splitk gemm appears to be losing precision VS reference implementation when FP numbers are involved. * Format * Verify 04_gemm_add_add_fastgelu on floating point numbers * Verify 20_grouped_conv_bwd_weight on floating point numbers * Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers * Verify more tests on floating point data * Fix data types and improve testing verbocity. * Add fp4 vectors * Add debug tests * Upgrade to NPI 573 build docker. * Skip on gemm_universal tests. The tests take too long to complete on the emulator. Need to see if it is possible to reduce the scope of the testing to just FP8 data types. * Add new mfma instructions and examples * Add preprocessor directives for gfx950 specific code * Fix gfx1101 build * Document test availability * Re-enable fp8 gemms for gfx94/95 * Cherry-pick GEMM Universal tests for FP8 data types * Cleanup * Add vector types and tests * Add check_err function * Add tensor generators * CK_USE_GFX94 has already been set on this branch * Fix * Address formatting issues and leftovers * Make fail/pass logic consistent within 01_gemm folder Removed multiple negations in fail/pass logic to propagate `true` as the success indicator. * Fix GPU verification reporting logic. * Update year in copyright notice. * Cleanup * Use `enum class` instead of `enum` * Remove set_property for FP8 tests * Add vector conversions * Fix * Fix linker errror * Clean up * Fix gfx950 conversions * Clean up * Fix more gfx950 conversions * Fix even more gfx950 conversions * Narrowing the scope of PR to OCP FP8 enablement only * Add tests for OCP FP8 vector_type storage * Fix client examples build * Fix typo * Update e8m0 casting * Rename E8M0 type * Update unpack method * Cleanup merge artifacts * Enable gemm kernel on all gfx9 architectures (#227) * clean-up * Implement `non_native_vector_base` with `ext_vector_type` array. (#232) * Enable support of 1, 2, 4, and 8-byte custom types in CK. * Fix pool tests for OCP FP8 data type * Fix build * Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on MI350 * fix clang format * Add new mfma instructions and examples * Add preprocessor directives for gfx950 specific code * Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on MI350 * fix clang format * Fix clang format for the newly merged files * Use the existing example instances for fp16 bf16 and int8 * Remove comment on new mfma instructions in MfmaInstr * Update include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * merge from public repo * Fix ck build * Fix ck build * Use double for max_abs_in_val * Move scaled_type_convert functions to a separate header (#251) * re-enable building mha lib and gemm_universal_f8 instances for gfx950 * Update library/src/tensor_operation_instance/gpu/CMakeLists.txt Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * fix typo for CK_USE_OCP_FP8 * fix typo for CK_USE_OCP_FP8 * Add FP6 and BF6 types (#261) * Add a rounding flag * Add FP6 and BF6 * Add tests Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * Clean up --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * fix one more typo * Refactor E8M0 scale implementation (#262) * Refactor E8M0 scale implementation * Add MXFP6 and MXBF6 conversion methods (#270) * Add conversions * Add tests * Add docstrings * Add scaled conversions * Add fp6/bf6 tests * Remove misleading fp4 test case * Add docstrings * Clean up * Address comments * Set stricter tolerances for RNE tests * Add missing tests * Add native conversions to float * Revert "Add native conversions to float" This reverts commit 09467111f73b753c8cc3d597533b187940353dab. * Update copyright years * replace the fp6 with bf6 convert calls in test_bf6 * fix test_bf6 * enable smfmac test * [MX FP8] Add Scaled Type Convert Functions for OCP FP8/BF8 data types (#271) * Move scaled_type_convert functions to a separate header * Introduce MX data tests * Build MX tests only on relevant architectures * Refactor E8M0 scale implementation * Fix `config.h` typo * Cleanup deprecated symbols * Refactor `amd_ck_fp8.hpp` * `scaled_type_convert` for `f8_ocp_t` * Implement test for MX FP8 scaled type convert * Implement test for MX BF8 scaled type convert * Scaled type convert for vectors of 2 FP8 elements * Scaled type convert for vectors of 16 FP8 elements * Implementation of scaled conversion from F32 to F8 * Add tests for scaled conversions from FP32 to FP8 * Add documentation to the test functions * Implementation of scaled conversion from F32x2 to F8x2 * Implementation of scaled conversion from F32x16 to F8x16 * Implementation of scaled conversion from F32x32 to F8x32 * Implementation of scaled conversion from F8x32 to F32x32 * Verified on the emulator * MX FP GEMM - Example Template (#277) Temporarily uses `DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3` kernel and 128x128 scaling matrices. Must be modified to use MX-native GEMM kernell with 16 or 32 component vectors per scale. Verified on the emulator. * Add vector support * Add tests * Add missing type aliases * Fix test naming * only build mx example for gfx950 * disable CK_USE_AMD_MFMA_GFX950 by default * fic build for multiple archs * fix typo * fix typo * Update unpack signature * Fix merge * Add size checks in pack function * Add a flag * Add conversions * Fix build logic * Update pack/unpack methods * Remove unneeded AsType accessors * Add docstrings * Add a flag to config file * Test the functionality of V_MFMA_F32_16X16X128_F8F6F4 and V_MFMA_F32_32X32X64_F8F6F4 instructions. (#293) * Introduced MFMA tests * Verified f8f6f4 MFMA Instructions * Move flag logic to scaled_type_convert header * Use pointers instead of array indices * Fix a typo * Update tests and pack functions * Fix gemm gemm on gfx950 * Fix clang format * restore the default gput target lists * fix the jenkinsfile * add missing ifdef --------- Co-authored-by: Jing Zhang Co-authored-by: aska-0096 Co-authored-by: Jun Liu Co-authored-by: Andriy Roshchenko Co-authored-by: Rostyslav Geyyer Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: root Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: jefyang1 <146495389+jefyang1@users.noreply.github.com> Co-authored-by: jefyang1 * restore cron trigger (#1863) * add vectorloads on non-k dim for memory pipelines (#1856) * Support for dtypes (fp8, bf8, bf16 and fp16) for the ck_tile/03_gemm example. (#1845) * Support bf16/fb8/bf8 datatypes for ck_tile/gemm * remove commented out code. * Addressing code review comments and enabling universal_gemm for all the supported data types. * Merge conflict resolution. * Solve the memory pipeline compilation error. Merge with the new change of CShuffle * finish the feature, pass the tests * Fix the pipeline and add the benchmark script for other data types --------- Co-authored-by: ThomasNing * Extract prec_str and add separator to concat * GetName add * CK Tile - small fix to hotloop scheduler & KPack value. (#1867) * Use SmemPack in HotLoop scheduler * Additional debug print information * Change KPack value. Hardcode for now, as without AK1/BK1 there's no good way to determine its value. * Fix HotLoopScheduler MFMA instr parameters. * Resolve merge issues --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Jing Zhang Co-authored-by: aska-0096 Co-authored-by: Jun Liu Co-authored-by: Andriy Roshchenko Co-authored-by: Rostyslav Geyyer Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: root Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: jefyang1 <146495389+jefyang1@users.noreply.github.com> Co-authored-by: jefyang1 Co-authored-by: jakpiase Co-authored-by: kylasa Co-authored-by: ThomasNing Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- example/ck_tile/03_gemm/gemm_basic.cpp | 7 +- example/ck_tile/03_gemm/gemm_basic.hpp | 2 +- example/ck_tile/03_gemm/run_gemm_example.inc | 4 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 7 +- .../run_batched_gemm_example.inc | 2 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 2 +- .../run_grouped_gemm_example.inc | 2 +- include/ck_tile/core.hpp | 2 +- include/ck_tile/host.hpp | 1 + include/ck_tile/host/concat.hpp | 122 ++++++++++++++++++ include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 1 + include/ck_tile/ops/batched_transpose.hpp | 1 + include/ck_tile/ops/common.hpp | 1 + include/ck_tile/ops/common/utils.hpp | 34 +++++ include/ck_tile/ops/elementwise.hpp | 1 + include/ck_tile/ops/epilogue.hpp | 1 + include/ck_tile/ops/flatmm.hpp | 1 + include/ck_tile/ops/fmha.hpp | 1 + include/ck_tile/ops/fused_moe.hpp | 1 + include/ck_tile/ops/gemm.hpp | 1 + .../ops/gemm/kernel/batched_gemm_kernel.hpp | 16 ++- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 8 ++ .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 12 ++ .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 10 ++ .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 11 ++ .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 3 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 18 ++- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 8 ++ .../gemm/pipeline/gemm_pipeline_problem.hpp | 15 ++- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 13 +- include/ck_tile/ops/image_to_column.hpp | 1 + include/ck_tile/ops/layernorm2d.hpp | 1 + include/ck_tile/ops/norm_reduce.hpp | 1 + include/ck_tile/ops/permute.hpp | 1 + include/ck_tile/ops/reduce.hpp | 1 + include/ck_tile/ops/rmsnorm2d.hpp | 1 + include/ck_tile/ops/smoothquant.hpp | 1 + include/ck_tile/ops/softmax.hpp | 1 + include/ck_tile/ops/topk.hpp | 1 + include/ck_tile/ops/topk_softmax.hpp | 1 + 40 files changed, 300 insertions(+), 18 deletions(-) create mode 100644 include/ck_tile/host/concat.hpp create mode 100644 include/ck_tile/ops/common/utils.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 2e04780eb0..5dc7b9cd0b 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 5fa94f5f72..ed02f89fac 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 028f8a44c3..5746aa2b7b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -171,7 +171,7 @@ int run_gemm_example_with_layouts(int argc, std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -229,7 +229,7 @@ int run_gemm_example_with_layouts(int argc, std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } return pass; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 949621e116..286fe4201d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index d0df8845cc..1105304e3e 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc, << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index c32fac6c0d..03d5818179 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -118,7 +118,7 @@ float grouped_gemm(const std::vector& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" + std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index b0a3e9973c..080ea818c9 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc, << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; } - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } return pass; diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index ba4f4b6e7d..a8c95b9c38 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,12 +27,12 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/tensor/buffer_view.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 39a904717c..5a5e01460f 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -5,6 +5,7 @@ #include "ck_tile/host/arg_parser.hpp" #include "ck_tile/host/check_err.hpp" +#include "ck_tile/host/concat.hpp" #include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp" #include "ck_tile/host/convolution_parameter.hpp" #include "ck_tile/host/device_memory.hpp" diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp new file mode 100644 index 0000000000..c68b908149 --- /dev/null +++ b/include/ck_tile/host/concat.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +template +struct IsCharArray : std::false_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v || + IsCharArray::value || + std::is_same_v)&&...); + +template +[[nodiscard]] auto concat(const Ts&... xs) + -> std::enable_if_t, std::string> +{ + using ::operator<<; + thread_local std::ostringstream oss; + oss.str(""); + + (oss << ... << xs); + return oss.str(); +} + +template +[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept +{ + return N; +} + +template +[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept +{ + return N; +} + +[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept +{ + const char* end = s; + while(*end++ != 0) {} + return end - s - 1; +} + +[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; } + +[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); } + +[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept +{ + return s.size(); +} + +template +auto concatInto(std::string& result, const Ts&... xs) + -> std::enable_if_t, void> +{ + const std::size_t space = (1 + ... + getSize(xs)); + result.reserve(result.size() + space); + ((result += xs), ...); +} + +template +[[nodiscard]] auto concat(const Ts&... xs) + -> std::enable_if_t, std::string> +{ + std::string result; + concatInto(result, xs...); + return result; +} + +// Function for types convertible to std::string_view +template +[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest) + -> std::enable_if_t, std::string> +{ + std::string result; + result += first; + ((result += sep, result += rest), ...); + return result; +} + +// Function for other types +template +[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest) + -> std::enable_if_t, std::string> +{ + using ::operator<<; + thread_local std::ostringstream oss; + oss.str(""); + oss << first; + ((oss << sep << rest), ...); + return oss.str(); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 8b5302257c..1768c802d5 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -10,3 +10,4 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index ade2f18041..200e2a618c 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 9b9bf30ad3..027e2fdd94 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -5,3 +5,4 @@ #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp new file mode 100644 index 0000000000..8592f93e0f --- /dev/null +++ b/include/ck_tile/ops/common/utils.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// clang-format off +template struct typeToStr; +template <> struct typeToStr { static constexpr const char * name = "fp32"; }; +template <> struct typeToStr { static constexpr const char * name = "fp16"; }; +template <> struct typeToStr { static constexpr const char * name = "bf16"; }; +template <> struct typeToStr { static constexpr const char * name = "fp8"; }; +template <> struct typeToStr { static constexpr const char * name = "bf8"; }; +template <> struct typeToStr { static constexpr const char * name = "int8"; }; +// clang-format on + +template +std::string gemm_prec_str() +{ + std::string base_str = std::string(typeToStr::name); + if(!std::is_same_v) + { + base_str += "_" + std::string(typeToStr::name); + } + return base_str; +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 15fa269740..53187771b9 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -6,3 +6,4 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 95ead2645e..9d2ed407c9 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 616db2fa5b..82f6d48eda 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 4cbb59e95b..c896534e03 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -44,3 +44,4 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index d2d328fc46..3ffb0a9ca2 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -17,3 +17,4 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 5bbe0601b7..a94628a59a 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -46,3 +46,4 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 0f8bec3cf4..323c682f2c 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel, + concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + struct BatchedGemmKernelArgs : GemmKernelArgs { index_t batch_stride_A; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index aa31d1fccf..4ed3006c89 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -75,6 +76,13 @@ struct GemmKernel static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 13d3df02f9..751e7c0e1a 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel, + concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 0a40ca359e..eec3886e2f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -81,6 +82,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base::PrefetchStages; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV3", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e23f0cda7d..f8dd2348cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -128,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrMe", + concat('x', MPerBlock, NPerBlock, KPerBlock), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + using Base::PrefetchStages; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 6f51e6b8a9..b18bf603a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include +#include #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index d9f04a87c3..a2a14d1017 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -39,6 +40,18 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr index_t kLdsAlignmentInBytes = 16; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1 auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * - 16; + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; // B tile in LDS BDataType* p_b_lds = static_cast( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 0417035fb6..ce2dc9fb96 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -25,6 +26,13 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV2", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize)); + // clang-format on + } CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index a69f72626c..dd631876b4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -35,9 +36,19 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - + static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr index_t VectorLoadSize = Traits::_VectorSize; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 2522abe5ed..24a399f18d 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -19,6 +20,16 @@ struct TileGemmShape static constexpr index_t kM = BlockTile::at(number<0>{}); static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kK = BlockTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_gemm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})), + concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{}))); + // clang-format on + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index d54b7f60d6..93664ea138 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 47d986e1c2..afbb817db1 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 9392f8b439..7dc3e8b7e7 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index f3abe84e46..1cc3d9cbc3 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -7,3 +7,4 @@ #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index b817d09c72..80ead84e85 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 73fd6bfb0e..3eec2a1ab6 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 3fe1b5b213..dc164dc1a0 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 391609622a..b23e869d81 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -7,3 +7,4 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 40b9edd72f..1dc563f757 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -7,3 +7,4 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index efc1d17637..d0a810de4f 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" From c0adab485020b83f324d2efdcac2c997e19443eb Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 11 Feb 2025 17:49:17 +0800 Subject: [PATCH 2/7] [CK_TILE] moe sorting ex kernel to support expert > 128 (#1840) * moe sorting ex * fix bug for race condition * fix bug and optimze large expert * fix * optimize with sub_token_oneshot * support skip empty tokens for expert sorting * update moe_sorting * tidy code --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 63 +- .../13_moe_sorting/moe_sorting_api.cpp | 82 +++ .../13_moe_sorting/moe_sorting_api.hpp | 3 +- .../13_moe_sorting/script/smoke_test.sh | 8 + example/ck_tile/15_fused_moe/README.md | 2 +- .../instances/fused_moesorting_api.cpp | 74 ++ .../host/reference/reference_moe_sorting.hpp | 26 +- include/ck_tile/ops/fused_moe.hpp | 2 +- .../fused_moe/kernel/fused_moegemm_kernel.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 693 ++++++++++++++++-- .../fused_moe/kernel/moe_sorting_problem.hpp | 52 ++ .../pipeline/moe_sorting_problem.hpp | 28 - 12 files changed, 936 insertions(+), 99 deletions(-) create mode 100644 include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp delete mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d2c4df1058..c4faa35e33 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) .insert("k", "4", "topk") .insert("unit", "32", "unit_size") .insert("moe_buf_size", "0", "moe_buf_size") + .insert("local_eid", + "-1", + "a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n" + "please make sure eid is in ascending order!") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") @@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) return false; } + bool local_expert_masking = args.get_str("local_eid") != "-1"; + auto local_expert_masking_host = [&]() { + if(local_expert_masking) + { + auto local_eid = args.get_int_vec("local_eid"); + // std::vector v_ {num_experts, 0}; + ck_tile::HostTensor v_{{num_experts}}; + v_.SetZero(); + for(auto eid : local_eid) + { + if(eid >= num_experts) + { + throw std::runtime_error( + "local_eid larger than number of expert, please check"); + } + v_.mData[eid] = 1; + } + return v_; + } + else + // return std::vector{}; + return ck_tile::HostTensor{{1}}; + }(); + // tokens already considered batch size ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); @@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem local_expert_masking_dev( + local_expert_masking_host.get_element_space_size_in_bytes()); topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); @@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) { moe_buf_dev.ToDevice(moe_buf_host.data()); } + if(local_expert_masking) + local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); - moe_sorting_trait trait{index_prec, weight_prec}; + moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), @@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) warmup, repeat}; auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk, - ms); + topk); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + if(ms < 0) printf("not supported\n"); + else + printf("ms:%f, ", ms); fflush(stdout); if(ms < 0) { @@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) int32_t ref_total_tokens_post_pad = 0; ck_tile::reference_moe_sorting(topk_ids_host, weights_host, + local_expert_masking_host, sorted_ids_ref, sorted_weights_ref, sorted_expert_ids_ref, ref_total_tokens_post_pad, num_experts, - unit_size); + unit_size, + local_expert_masking); rtn &= ck_tile::check_err( sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); rtn &= ck_tile::check_err(sorted_weights_host, @@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("valid:%s", rtn ? "y" : "n"); + fflush(stdout); + if(!rtn) + printf(", (%d)", seed); + printf("\n"); fflush(stdout); return rtn; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 723fb3f69f..abff24a669 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,6 +3,12 @@ #include "moe_sorting_api.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,67 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +105,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + (void)c_; + + MOE_SORTING_DISPATCH_EMASK_(r_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0cb393f7de..5bda4d368a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,7 +10,8 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool local_expert_masking; // if mask experts as local expert }; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 3ff8a7332d..cf2c2e164b 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 $EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md index b6ceabf351..089e1de78e 100644 --- a/example/ck_tile/15_fused_moe/README.md +++ b/example/ck_tile/15_fused_moe/README.md @@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator: // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7ca24c5c9a..805cd54878 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,6 +3,12 @@ #include "fused_moesorting.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,24 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +62,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + (void)c_; + if(is_sub_token_onshot) + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, true); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, true); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, true); + } + else + { + MOE_SORTING_DISPATCH_(1, true); + } + } + else + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, false); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, false); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, false); + } + else + { + MOE_SORTING_DISPATCH_(1, false); + } + } + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 3851629cc2..47f0ba576b 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -14,12 +14,15 @@ namespace ck_tile { template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, const HostTensor& weights, + const HostTensor& local_expert_mask, HostTensor& p_sorted_token_ids, HostTensor& sorted_weight, HostTensor& sorted_expert_ids, index_t& unit_cnt, const index_t experts, - const index_t unit_size) + const index_t unit_size, + bool local_expert_masking, + bool skip_experts_with_zero_token = true) { const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t topk = topk_ids.mDesc.get_lengths()[1]; @@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, #endif std::vector> expert_token_weights( experts, std::vector(unit_size, 0)); + // count number of unit-size slices in this expert std::vector expert_slices(experts, 1); + // count the tokens used in this expert std::vector expert_slice_idxs(experts, 0); + // TODO: above 2 buffer seems duplicated for(index_t t = 0; t < num_token; t++) { @@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, IndexType* out_tokens = p_sorted_token_ids.data(); WeightType* out_weights = sorted_weight.data(); IndexType* out_expert_id = sorted_expert_ids.data(); + int curr_expert_id = 0; for(index_t e = 0; e < experts; e++) { + if(local_expert_masking) + { + if(local_expert_mask(e) == 0) + continue; + } + if(skip_experts_with_zero_token) + { + if(expert_slice_idxs[e] == 0) + { + curr_expert_id++; + continue; + } + } + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); out_tokens += expert_slices[e] * unit_size; memcpy(out_weights, @@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, for(index_t s = 0; s < expert_slices[e]; s++) { - out_expert_id[s] = e; + out_expert_id[s] = curr_expert_id; unit_cnt++; } out_expert_id += expert_slices[e]; + curr_expert_id++; } unit_cnt *= unit_size; return; diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 3ffb0a9ca2..ddb64a2189 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" @@ -14,7 +15,6 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index a7eeb3c0e3..efa1ccb311 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -22,7 +22,7 @@ // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 30e68996b6..340f6cb9e5 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -15,6 +15,10 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -28,7 +32,7 @@ namespace ck_tile { // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] @@ -55,6 +59,34 @@ namespace ck_tile { // num_tokens_post_padded_ptr : [28] // num_sorted_tiles_ptr : [7] // +// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens) +// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5] +// num_tokens_post_padded_ptr : [24] +// +// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case) +// and modify the output expert-ID, because we will only have enbaled expert on specific GPU. +// we call expert input to this kernel as "global expert id", output as "local expert id" +// +// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id") +// num_tokens_post_padded_ptr : [20] +// // * different from vLLM // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id // 2)need sorted_weight_ptr @@ -67,10 +99,80 @@ namespace ck_tile { // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) + + +CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_) +{ + /* num_experts + 1 + * +--------------------------------------+ + * | | + * | | + * | | * -> sub-tokens + * | | + * | | + * +--------------------------------------+ + * | | 2 -> cumsum buffer + * +--------------------------------------+ + * + */ + int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here + int smem_rows = [&](){ + index_t target_occupancy_ = 2; + constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t sub_unroll = 8; + constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt + // at lease 2 lines, one for sub_token unroll, one for cumsum + // should be enough + if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) { + if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols)) + throw std::runtime_error("too many num_experts, can't allocate smem"); + target_occupancy_ = 1; + } + int r = total_ / target_occupancy_ / smem_cols; + + // round to sub_unroll multipl + int r_for_sub_token = r - cumsum_bufs; + r_for_sub_token = min(r_for_sub_token, num_tokens_); + r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll; + r_for_sub_token = max(r_for_sub_token, 1); + + if(r_for_sub_token > 1) + { + int r_unroll_ = r_for_sub_token / sub_unroll; + + + // round to 1x/2x/4x/8x number of sub_unroll + int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29 + int mask_ = (1 << (31 - clz_)) - 1; + + + mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most + mask_ = ~mask_; + //printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout); + + r_for_sub_token = (r_unroll_ & mask_) * sub_unroll; + } + + // final check + if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) { + throw std::runtime_error("can't run this kernel, request LDS over size"); + } + + return r_for_sub_token + cumsum_bufs; + }(); + + // printf("r:%d, c:%d\n", smem_rows, smem_cols); + + return ck_tile::make_tuple(smem_rows, smem_cols); +} + struct MoeSortingHostArgs { const void* p_topk_ids; // [token, topk] const void* p_weights; // [token, topk] + + const void* p_local_expert_mask; + void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -101,6 +203,7 @@ struct MoeSortingKernel { const void* p_topk_ids; const void* p_weights; + const void* p_local_expert_mask; void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -111,8 +214,11 @@ struct MoeSortingKernel index_t moe_buf_bytes; index_t tokens_per_thread; + index_t smem_rows; mdiv unit_size_mdiv; mdiv topk_mdiv; + mdiv expert_mdiv; + // mdiv sub_tokens_mdiv; }; CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) @@ -123,15 +229,25 @@ struct MoeSortingKernel CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + (void)h; + return dim3(256); +#else return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); +#endif } // in byte CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + return smem_rows * smem_cols * sizeof(int); +#else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t); +#endif } CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) @@ -139,6 +255,7 @@ struct MoeSortingKernel Kargs k; k.p_topk_ids = h.p_topk_ids; k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -152,10 +269,18 @@ struct MoeSortingKernel k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.smem_rows = [&](){ + auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + (void) c_; + return r_; + }(); + k.expert_mdiv = mdiv{static_cast(h.num_experts)}; + // k.sub_tokens_mdiv = mdiv{static_cast(k.smem_rows - 1)}; return k; } - // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // NOTE: wave_size need at least be 16!! dpp 16 is one row template __device__ inline void wave_cumsum(data_t& thread_data) const { @@ -196,6 +321,40 @@ struct MoeSortingKernel bank_mask, bound_ctrl))); // row_shr:4 } + if constexpr(wave_size == 8) { + + // wave-size=8 need one extra shift + thread_data = + reduce_op(thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 +#if 0 + constexpr int bank_mask_0_7 = 0b1100; + auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; + thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, + __builtin_amdgcn_update_dpp(0, /* old value */ + __builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask_0_7, + bound_ctrl))// row_newbcast:7 + ); +#else + data_t xxx =__builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask, + bound_ctrl)); // row_newbcast:7 + + data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; + thread_data = thread_data - yyy; +#endif + + } if constexpr(wave_size > 8) { thread_data = @@ -224,6 +383,36 @@ struct MoeSortingKernel } } + // reduce single pixel within a wave + template + __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) + { + // constexpr int wave_size = 64; + // constexpr int reduce_stage = 6; // 1<<6=64 + // clang-format off + constexpr int reduce_stage = [](){ + if constexpr(wave_size_ == 2) return 1; + else if constexpr(wave_size_ == 4) return 2; + else if constexpr(wave_size_ == 8) return 3; + else if constexpr(wave_size_ == 16) return 4; + else if constexpr(wave_size_ == 32) return 5; + else if constexpr(wave_size_ == 64) return 6; + else return 0; + }(); + // clang-format on + T v_local = local; +#pragma unroll reduce_stage + for(int i_stage = 0; i_stage < reduce_stage; i_stage++) + { + int src_lane = __lane_id() ^ (1 << i_stage); + int32_t v_remote_tmp = + __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(v_local)); + T v_remote = bit_cast(v_remote_tmp); + v_local = reduce_f(v_local, v_remote); + } + return v_local; + } + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const { return row * total_col + col; @@ -257,37 +446,37 @@ struct MoeSortingKernel index_t* shared_mem = reinterpret_cast(smem); index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) - index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1) + index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1) for(int i = 0; i < num_experts; ++i) { - tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0; + tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0; } #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])]; + ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])]; } __syncthreads(); #if 1 if(tid < num_experts) { - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; index_t local_c[8]; index_t prev_c = 0; // TODO: manually unroll. pragma unroll does not work well when we have dependency - for(int i = 1; i <= static_cast(blockDim.x); i+= 8) + for(int i = 1; i <= static_cast(blockDim.x); i += 8) { - local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)]; - local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)]; - local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)]; - local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)]; - local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)]; - local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)]; - local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)]; - local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)]; + local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)]; + local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)]; + local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)]; + local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)]; + local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)]; + local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)]; + local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)]; + local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)]; local_c[0] += prev_c; local_c[1] += local_c[0]; @@ -299,51 +488,57 @@ struct MoeSortingKernel local_c[7] += local_c[6]; prev_c = local_c[7]; - tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0]; - tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1]; - tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2]; - tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3]; - tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4]; - tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5]; - tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6]; - tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7]; + tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0]; + tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1]; + tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2]; + tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3]; + tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4]; + tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5]; + tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6]; + tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7]; } } #else - // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic + // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future + // heuristic { if(tid < num_experts) - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; - for(int i = 0; i < num_experts; i+=8) { + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; + for(int i = 0; i < num_experts; i += 8) + { index_t local_c[8]; - #pragma unroll - for(int j = 0; j < 8; j++) { - local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)]; } - #pragma unroll - for(int j = 0; j < 8; j++) { +#pragma unroll + for(int j = 0; j < 8; j++) + { wave_cumsum(local_c[j]); } - #pragma unroll - for(int j = 0; j < 8; j++) { - tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j]; } } } #endif __syncthreads(); - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { if(tid == 0) { cumsum[0] = 0; for(int i = 1; i <= num_experts; ++i) { auto current_units = [&]() { - index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] + - unit_size_mdiv.divisor - 1; + index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; index_t y_ = unit_size_mdiv.div(x_); return max(y_, 1) * unit_size_mdiv.divisor; }(); @@ -351,20 +546,24 @@ struct MoeSortingKernel } *p_total_tokens_post_pad = cumsum[num_experts]; } - } else { - // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert) - // for simplicity, not check experts here. - int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + } + else + { + // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= + // expert) for simplicity, not check experts here. + int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; - int local_cumsum = padded_tokens_per_expert; + int local_cumsum = padded_tokens_per_expert; wave_cumsum(local_cumsum); - if(tid == (num_experts - 1)) { - cumsum[0] = 0; + if(tid == (num_experts - 1)) + { + cumsum[0] = 0; *p_total_tokens_post_pad = local_cumsum; } - if(tid < num_experts) { + if(tid < num_experts) + { cumsum[tid + 1] = local_cumsum; } } @@ -373,7 +572,7 @@ struct MoeSortingKernel if(tid < num_experts) { int e_start = cumsum[tid]; - int e_end = cumsum[tid + 1]; + int e_end = cumsum[tid + 1]; for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) { p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; @@ -383,8 +582,8 @@ struct MoeSortingKernel #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - index_t expert_id = topk_id[i]; - index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)]; + index_t expert_id = topk_id[i]; + index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)]; index_t rank_post_pad = local_cnt + cumsum[expert_id]; #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID uint32_t curr_token_id, curr_topk_id; @@ -393,16 +592,17 @@ struct MoeSortingKernel #else p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); #endif - p_sorted_weights[rank_post_pad] = weights[i]; - tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1; + p_sorted_weights[rank_post_pad] = weights[i]; + tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1; } - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { const index_t prefill_token = topk_mdiv.div(numel); if(tid < num_experts) { index_t expert_offset = - cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; index_t expert_end = cumsum[tid + 1]; while(expert_offset < expert_end) { @@ -417,16 +617,19 @@ struct MoeSortingKernel } } } - else { + else + { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; { - index_t eid = tid / experts_per_wave; - index_t expert_offset = - cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave; + index_t eid = tid / experts_per_wave; + index_t expert_offset = cumsum[eid] + + tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] + + tid % experts_per_wave; index_t expert_end = cumsum[eid + 1]; - if(eid < num_experts) { + if(eid < num_experts) + { while(expert_offset < expert_end) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID @@ -436,10 +639,363 @@ struct MoeSortingKernel p_sorted_token_ids[expert_offset] = prefill_token; #endif p_sorted_weights[expert_offset] = static_cast(0.0); - expert_offset+=experts_per_wave; + expert_offset += experts_per_wave; } } - } + } + } + } + + // only support index_t, and single pixel access + struct simple_smem_indexer + { + index_t* smem; + index_t row_stride; + + // this is 2D + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_) + : smem(smem_), row_stride(row_stride_) + { + } + CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const + { + return smem[i_row * row_stride + i_col]; + } + CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col) + { + return smem[i_row * row_stride + i_col]; + } + + // this is 1D or linear + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {} + CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; } + CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; } + }; + + CK_TILE_DEVICE void + moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id, + const WeightType* __restrict__ weights, + const IndexType* __restrict__ local_expert_mask, + index_t* p_sorted_token_ids, + WeightType* p_sorted_weights, + index_t* p_sorted_expert_ids, + index_t* p_total_tokens_post_pad, + const index_t num_experts, + const index_t tokens, + const mdiv unit_size_mdiv, + const mdiv topk_mdiv, + const mdiv expert_mdiv, + const index_t smem_rows, + void* smem) const + { + const index_t tid = static_cast(threadIdx.x); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize); + const index_t lid = __lane_id(); + constexpr index_t block_size = 256; // blockDim.x; + const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; + const index_t topk = topk_mdiv.divisor; + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + const index_t smem_cols = num_experts + 1; + + simple_smem_indexer smem_cumsum{reinterpret_cast(smem) + 0}; + simple_smem_indexer smem_cumdup{reinterpret_cast(smem) + smem_cols}; + simple_smem_indexer smem_tokens{reinterpret_cast(smem) + 2 * smem_cols, + smem_cols}; + + // #pragma unroll 8 + for(int i = tid; i < (sub_tokens * num_experts); i += block_size) + { + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + // NOTE: below for loop can't have barrier inside!! + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id, curr_topk_id; + topk_mdiv.divmod(i, curr_token_id, curr_topk_id); + int i_t = i_token + curr_token_id; + + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + + if constexpr(Problem::SubTokenOneShot) + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; + else + smem_tokens(curr_token_id, eid)++; + } + __builtin_amdgcn_s_waitcnt(0xc07f); + } + __syncthreads(); // make sure different i_token iteration not overlap by different wave + } + + // counting + if(tid == 0) + { + smem_cumsum(0) = 0; + // smem_cumdup(0) = 0; + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + + for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm) + { + index_t local_c[Problem::SubTokenTile]; + index_t cnt = 0; + + for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile) + { +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e); + if constexpr(Problem::SubTokenOneShot) + { + local_c[j] = local_c[j] != 0 ? 1 : 0; + } + } + +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + cnt += wave_reduce(local_c[j], f_sum, number<8>{}); + } + } + if(lane_group_os == 0) + smem_cumsum(i_e + 1) = cnt; + } + } + + if constexpr(Problem::LocalExpertMasking) + { + smem_cumdup(0) = 0; + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + // reuse this buffer + smem_cumdup(i_e + 1) = local_expert_mask[i_e]; + } + } + + __syncthreads(); + + { + if(wid == 0) + { + // NOTE: under this block can never use __syncthreads! + int i_e_ = 0; + int local_cumsum_ = 0; + for(; i_e_ < num_experts; i_e_ += warpSize) + { + int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); + int local_cnt = smem_cumsum(i_e_ + lid + 1); + int blocks_pers_expert = + unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); + + int pre_cumsum_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(lid == 0 ? i_e_ : 0); + else + return 0; // not used + }(); + int local_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(i_e_ + lid + 1); + else + return 0; // not used + }(); + int padded_tokens_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert * unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return local_masking ? x_ : 0; + } + else + return x_; + }(); + + local_cumsum_ = padded_tokens_per_expert; + local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local + // cumsum padded in case local cumsum is zero, but + // pre_sumsum has value, which will result int + // zero local cumsum(but we want at least padded) + wave_cumsum(local_cumsum_); + + if((i_e_ + lid) < num_experts) + smem_cumsum(i_e_ + lid + 1) = local_cumsum_; + + if constexpr(Problem::LocalExpertMasking) + { + local_masking += pre_cumsum_masking; + wave_cumsum(local_masking); + if((i_e_ + lid) < num_experts) + smem_cumdup(i_e_ + lid + 1) = local_masking; + } + + // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt() + // for above write however __syncthreads will cause barrier with waves other + // than 0(which is not we want) + __builtin_amdgcn_s_waitcnt(0xc07f); + } + if((lid + i_e_ - warpSize) == (num_experts - 1)) + { + *p_total_tokens_post_pad = local_cumsum_; + } + } + __syncthreads(); + } + + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + int e_start = smem_cumsum(i_e); + int e_end = smem_cumsum(i_e + 1); + + int expert_id = [&]() { + if constexpr(Problem::LocalExpertMasking) + { + // local expert id from cumsum + return smem_cumdup(i_e); + } + else + return i_e; + }(); + + smem_cumdup(i_e) = e_start; // duplicate cumsum for later use + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[i_e] == 0) + continue; + } + + for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) + { + p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; + } + } + smem_cumdup(num_experts) = smem_cumsum(num_experts); + + // fill the p_sorted_token_ids/p_sorted_weights + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + if constexpr(!Problem::SubTokenOneShot) + { + // clear every time + for(int i = tid; i < (sub_tokens * num_experts); i += block_size) + { + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + // load again + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id_, curr_topk_id_; + topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_); + int curr_token_id = static_cast(curr_token_id_); + int curr_topk_id = static_cast(curr_topk_id_); + int i_t = i_token + curr_token_id; + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1 + } + } + __syncthreads(); + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm) + { + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[eid] == 0) + continue; + } + int position = smem_cumsum(eid); + for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens; + i_sub_token += lane_group_sz) + { + auto x = smem_tokens(i_sub_token, eid); + + int local_cnt_cache = x != 0 ? 1 : 0; + int local_cnt = local_cnt_cache; + wave_cumsum(local_cnt); + if(x != 0) + { + // now x is topk value +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[position + local_cnt - 1] = + MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1); +#else + p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token; +#endif + p_sorted_weights[position + local_cnt - 1] = + weights[(i_token + i_sub_token) * topk + x - 1]; + } + + int remote_cnt = __builtin_amdgcn_ds_bpermute( + (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt); + + position += remote_cnt; + } + smem_cumsum(eid) = position; + } + } + __syncthreads(); + } + + // add the skip number + for(int eid = tid; eid < num_experts; eid += block_size) + { + int e_start = smem_cumsum(eid); + int e_end = smem_cumdup(eid + 1); + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + while(e_start < e_end) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk); +#else + p_sorted_token_ids[e_start] = tokens; +#endif + p_sorted_weights[e_start] = static_cast(0.0); + e_start++; + } } } @@ -456,6 +1012,24 @@ struct MoeSortingKernel } const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; extern __shared__ char smem[]; +#if MOE_SORTING_USE_EX_KERNEL + (void)numel; + return moe_align_block_size_kernel_ex( + static_cast(kargs.p_topk_ids), + static_cast(kargs.p_weights), + static_cast(kargs.p_local_expert_mask), + static_cast(kargs.p_sorted_token_ids), + static_cast(kargs.p_sorted_weights), + static_cast(kargs.p_sorted_expert_ids), + static_cast(kargs.p_total_tokens_post_pad), + kargs.num_experts, + kargs.tokens, + kargs.unit_size_mdiv, + kargs.topk_mdiv, + kargs.expert_mdiv, + kargs.smem_rows, + smem); +#else return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), static_cast(kargs.p_sorted_token_ids), @@ -468,6 +1042,7 @@ struct MoeSortingKernel kargs.unit_size_mdiv, kargs.topk_mdiv, smem); +#endif } }; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp new file mode 100644 index 0000000000..15effe7118 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +namespace ck_tile { + +template +struct MoeSortingProblem +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = + InternalLoadUnroll_; // TODO: need better design(like tile size) + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +template +struct MoeSortingProblemEx +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t SubTokenTile = SubTokenTile_; + static constexpr bool SubTokenOneShot = SubTokenOneShot_; + static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp deleted file mode 100644 index 50005c4402..0000000000 --- a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include -#include - -namespace ck_tile { - -template -struct MoeSortingProblem -{ - // TODO: this kernel only support warp per row - using WeightType = remove_cvref_t; - using IndexType = remove_cvref_t; - - static constexpr index_t WarpSize = get_warp_size(); - static constexpr index_t WarpsPerBlock = 1; - static constexpr index_t InternalLoadUnroll = - InternalLoadUnroll_; // TODO: need better design(like tile size) - static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out -}; -} // namespace ck_tile From b5ca008d62f7f4d0aa23735acfa7dfc4bc682f78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mirza=20Halil=C4=8Devi=C4=87?= <109971222+mirza-halilcevic@users.noreply.github.com> Date: Tue, 11 Feb 2025 17:07:24 +0100 Subject: [PATCH 3/7] Introduce gemm_softmax_gemm to codegen (#1542) * Introduce ck_host library and gemm_softmax_gemm. * Minor refactor. * Add descriptor to gemm_softmax_gemm. * Bugfix. * Revert ck_host library. * fix clang format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin --- .../operation.hpp | 61 +++ .../problem.hpp | 47 ++ .../host/device_gemm_multiple_d/operation.hpp | 2 + codegen/include/ck/host/operation/gemm.hpp | 20 + codegen/include/ck/host/types.hpp | 15 + .../src/device_batched_gemm_softmax_gemm.cpp | 38 ++ ...mm_softmax_gemm_operation_xdl_cshuffle.cpp | 408 ++++++++++++++++++ ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 102 +++-- codegen/src/types.cpp | 20 + codegen/test/rtc/include/rtc/hip.hpp | 1 + example/ck_tile/03_gemm/run_gemm_example.inc | 74 ++-- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 373 +++++++++++++++- 12 files changed, 1071 insertions(+), 90 deletions(-) create mode 100644 codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp create mode 100644 codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp create mode 100644 codegen/src/device_batched_gemm_softmax_gemm.cpp create mode 100644 codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp new file mode 100644 index 0000000000..301df0a529 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines all values need for an instance of fwd conv +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + TensorDesc B1{}; + TensorDesc C{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string b1_elem_op = PassThrough; + std::string c_elem_op = PassThrough; + std::string acc_elem_op = Scale; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDescGemmGemm tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b0_block_transfer{}; + operation::BlockTransferDesc b1_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + bool mask_out_upper_triangle = false; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool + IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp new file mode 100644 index 0000000000..428034a3ba --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines the problem specification for a GEMM operation +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index 359da7d8cf..e5eeb6be15 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + LoopScheduler loop_scheduler{}; + PipelineVersion pipeline_version{}; // functions to update fusion operators if provided void update_prologue(const std::string& prologue); diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index 84ef92f0a0..5a51a0002e 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -23,6 +23,26 @@ struct TileDesc int n_Xdl_per_wave = 0; int num_gemmk_prefetch_stage = 0; }; + +struct TileDescGemmGemm +{ + int block_size = 0; + int gemm01_m_per_block = 0; + int gemm0_n_per_block = 0; + int gemm0_k_per_block = 0; + int gemm1_n_per_block = 0; + int gemm1_k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int b1k1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int gemm0_m_Xdl_per_wave = 0; + int gemm0_n_Xdl_per_wave = 0; + int gemm1_n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; + struct BlockTransferDesc { std::string thread_cluster_length = ""; diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 8bad7bf89c..b05e134176 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -66,6 +66,20 @@ enum class GemmType }; std::string ToString(GemmType gt); +enum class LoopScheduler +{ + Default, + Interwave, +}; +std::string ToString(LoopScheduler ls); + +enum class PipelineVersion +{ + v1, + v2 +}; +std::string ToString(PipelineVersion pv); + struct TensorDesc { DataType element; @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; +constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale"; } // namespace host } // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..cf140ead1d --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..b12c2e1a4a --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// calculate appropriate Gemm Specification based on input tensor dimensions +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t n1, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block, + const std::size_t n1_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) + spec += "O"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!prologue.empty()) + { + this->prologue = pro; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epilogue.empty()) + { + this->epilogue = epi; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| +// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| +// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| +// | | | | | | | | | | | Wave| Wave| Wave| | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, +// Padded fallback kernel + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, +// Irregular k + { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, + // clang-format on + }; + + const std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Padded fallback kernel + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Irregular k + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + // clang-format on + }; + + const std::vector b1_block_descriptions = { + // clang-format off +// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Padded fallback kernel + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Irregular k + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 8}, + { 1, 4}, + { 1, 8}, + { 1, 4}, +// Padded fallback kernel + { 1, 2}, + { 1, 2}, +// Irregular k + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, +// Padded fallback kernel + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, +// Irregular k + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b1_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a + x.b1_block_transfer = b1_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; + x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.b1_elem_op = prob.B1ElementOp; + x.c_elem_op = prob.CElementOp; + x.acc_elem_op = prob.AccElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + prob.O, + x.tile_desc.gemm01_m_per_block, + x.tile_desc.gemm0_n_per_block, + x.tile_desc.gemm0_k_per_block, + x.tile_desc.gemm1_n_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + x.mask_out_upper_triangle = true; + result.push_back(x); + + x.mask_out_upper_triangle = false; + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + Problem prob; + prob.TransA = false; + prob.TransB = true; + prob.TransB1 = false; + prob.TransC = false; + + return {CreateOperations(prob, prologue, epilogue)}; +} + +static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, " + "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " + "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, " + "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " + "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, " + "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " + "${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " + "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " + "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " + "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " + "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " + "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " + "${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " + "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.b1k1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB0", ToString(this->B.layout)}, + {"LayoutB1", ToString(this->B1.layout)}, + {"LayoutC", ToString(this->C.layout)}, + {"ADataType", ToString(this->A.element)}, + {"B0DataType", ToString(this->B.element)}, + {"B1DataType", ToString(this->B1.element)}, + {"CDataType", ToString(this->C.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"AElementwiseOperation", this->a_elem_op}, + {"B0ElementwiseOperation", this->b_elem_op}, + {"Acc0ElementwiseOperation", this->acc_elem_op}, + {"B1ElementwiseOperation", this->b1_elem_op}, + {"CElementwiseOperation", this->c_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, + {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, + {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, + {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, + {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"B1K1", std::to_string(this->tile_desc.b1k1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, + {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, + {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"B0BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b0_block_transfer.thread_cluster_length}, + {"B0BlockTransferThreadClusterArrangeOrder", + this->b0_block_transfer.thread_cluster_arrange_order}, + {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order}, + {"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)}, + {"B0BlockTransferSrcScalarPerVector", + std::to_string(this->b0_block_transfer.src_scalar_per_vector)}, + {"B0BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, + {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, + {"B1BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b1_block_transfer.thread_cluster_length}, + {"B1BlockTransferThreadClusterArrangeOrder", + this->b1_block_transfer.thread_cluster_arrange_order}, + {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order}, + {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)}, + {"B1BlockTransferSrcScalarPerVector", + std::to_string(this->b1_block_transfer.src_scalar_per_vector)}, + {"B1BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, + {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CBlockTransferScalarPerVector_NWaveNPerXdl", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)}, + }; + + return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fff75c1962..fe556615e0 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) // accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } +// clang-format off +// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, + +// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +// clang-format on + // Hard-code tuning parameters in modularized fashion, string them together into a vector of // instances std::vector Operation_Xdl_CShuffle::CreateOperations( @@ -83,6 +89,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, +// Irregular tile + { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, // clang-format on }; @@ -100,6 +108,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -109,15 +119,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, }; std::vector b_block_descriptions_rowmajor = { @@ -134,6 +146,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on }; @@ -151,6 +165,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -167,6 +183,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 1, 1}, { 1, 1}, { 1, 1}, + { 1, 1}, { 1, 1}, // clang-format on }; @@ -185,6 +202,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<1, 16, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, +// Irregular tile + { S<1, 16, 1, 4>, 1}, // clang-format on }; @@ -199,33 +218,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); - // Put all values together into a single operation > store into the result vector - for(std::size_t i = 0; i < tile_descriptions.size(); i++) + const std::vector> scheduler_pipeline_descriptions = + { + {LoopScheduler::Default, PipelineVersion::v1}, + {LoopScheduler::Interwave, PipelineVersion::v1}, + {LoopScheduler::Default, PipelineVersion::v2}, + }; + for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions) { - Operation_Xdl_CShuffle x; - x.tile_desc = tile_descriptions[i]; - x.a_block_transfer = a_block_descriptions[i]; - x.b_block_transfer = b_block_descriptions[i]; - x.cshuffle = cshuffle_descriptions[i]; - x.c_block_transfer = c_block_descriptions[i]; - x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; - x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; - x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; - x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { - return TensorDesc{dt, ToLayout(trans)}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; - x.gemm_specialization = GetGemmSpec(prob.M, - prob.N, - prob.K, - x.tile_desc.m_per_block, - x.tile_desc.n_per_block, - x.tile_desc.k_per_block); - x.update_prologue(prologue); - x.update_epilogue(epilogue); - result.push_back(x); + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.loop_scheduler = loop_scheduler; + x.pipeline_version = pipeline_version; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } } return result; } @@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " - "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>"; // use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const @@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDEBlockTransferScalarPerVector_NPerBlock", std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"LoopScheduler", ToString(this->loop_scheduler)}, + {"PipelineVersion", ToString(this->pipeline_version)}, }; return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index 9aa5d39fae..a60e36ca4a 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -59,6 +59,26 @@ std::string ToString(GemmType gt) throw std::runtime_error("Incorrect gemm type"); } +std::string ToString(LoopScheduler ls) +{ + switch(ls) + { + case LoopScheduler::Default: return "ck::LoopScheduler::Default"; + case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave"; + } + throw std::runtime_error("Incorrect LoopScheduler type"); +} + +std::string ToString(PipelineVersion pv) +{ + switch(pv) + { + case PipelineVersion::v1: return "ck::PipelineVersion::v1"; + case PipelineVersion::v2: return "ck::PipelineVersion::v2"; + } + throw std::runtime_error("Incorrect PipelineVersion type"); +} + std::string SequenceStr(const std::vector& v) { return "ck::Sequence<" + diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index af2f4a9122..3163bb08ed 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace rtc { diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 5746aa2b7b..13a1c30e43 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + float ave_time = + gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name - << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name - << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } @@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + invoke_gemm( + a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index bfbcebd7c8..ea5a5d0e16 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } + static constexpr bool + IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? Gemm1NRaw : MRaw; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); } // polymorphic @@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } + + template + struct Descriptor + { + template + static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) + { + const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) + { + const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) + { + const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) + { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; + CGridDesc_M_N c_grid_desc_m_n; + C0MatrixMask c0_matrix_mask; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_descriptor_mblock_mperblock_nblock_nperblock; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + bool has_main_k_block_loop = true; + bool is_valid = false; + + constexpr Descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, + b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, + b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, + c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + c_grid_descriptor_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + c0_matrix_mask{c.GetLength(I1)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} + { + } + + constexpr bool IsValid() const { return is_valid; } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + return Descriptor( + a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const float scale, + const ADataType* __restrict__ p_a_grid, + const ADataType* __restrict__ p_b_grid, + const ADataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid) + { +#ifndef __HIPCC_RTC__ + assert(desc.is_valid); +#endif + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + AccElementwiseOperation acc_element_op{scale}; + + if(desc.has_main_k_block_loop) + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + else + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + } }; } // namespace device From 660db601844d439563a7db0cb27f4bf4fab794aa Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:24:03 -0800 Subject: [PATCH 4/7] replace docker credentials (#1881) --- Jenkinsfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 835b7e724f..80392bfbed 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){ { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() } } @@ -148,7 +148,7 @@ def buildDocker(install_prefix){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs) - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' @@ -162,7 +162,7 @@ def buildDocker(install_prefix){ catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } } From 8086bbe3a78d931eb96fe12fdc014082e18d18d3 Mon Sep 17 00:00:00 2001 From: Andres Lugo <108368282+alugorey@users.noreply.github.com> Date: Tue, 11 Feb 2025 12:11:46 -0600 Subject: [PATCH 5/7] Add receipt 4 option to codegen (#1875) * Add receipt 4 option to codegen * Remove repeated code * Review comments --- example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 10 +++++++++- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 9 ++++++++- example/ck_tile/01_fmha/generate.py | 3 ++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 83a1e82d6d..c05660c8ab 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1c9d743f3d..ad8daba17e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 5b1b6664cc..a0fb42aa11 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -103,7 +103,8 @@ if __name__ == "__main__": required=False, help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration" + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration" ) args = parser.parse_args() From 78195cccad673825f046523e84c503de7e741ef1 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 11 Feb 2025 13:26:11 -0800 Subject: [PATCH 6/7] add -Wno-unique-object-duplication compiler option (#1882) --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1fe1bc91d5..e90f893de0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,7 @@ endif() add_compile_options(-Wno-bit-int-extension) add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) +add_compile_options(-Wno-unique-object-duplication) if(DL_KERNELS) add_definitions(-DDL_KERNELS) From 3c7fef7f80ebefda76361b3f87868d91ff39e5b7 Mon Sep 17 00:00:00 2001 From: JonathanLichtnerAMD Date: Tue, 11 Feb 2025 17:25:00 -0700 Subject: [PATCH 7/7] Conditionally log a DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle warning (#1860) The code was emitting a warning if MIOpen did not create a workspace prior to invoking the IsSupportedArgument method, but the condition for MIOpen to create a workspace was not met, and so this condition was not really an error but more of a log message. This commit addresses this issue by using the CK_LOGGING facility to only generate the log message if the CK_LOGGING environment variable is set. --- ...grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index b4cf996a48..795995d9a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // if workspace is not allocated if(!arg.p_workspace_) { - std::cerr << "Warning: Workspace for " - "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " - "allocated, use SetWorkSpacePointer." - << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } return false; } if(!ck::is_xdl_supported())