From 7cb1f30cfb6045bccbbd484c5e8e4715e2ebc2f3 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:02:21 -0500 Subject: [PATCH 001/172] Remove default constructor to fix c++17 build issue (#2953) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove default constructor to fix build issue * Restore default CTOR, remove constexpr, add init --------- Co-authored-by: Bartłomiej Kocot --- include/ck/tensor_description/multi_index_transform.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index e24227ecc3..ecc3dcf4a0 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1586,7 +1586,7 @@ struct ConvBwdDataImplicitGemmOutTransform Tuple low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_ - __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default; + __host__ __device__ ConvBwdDataImplicitGemmOutTransform() = default; __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N, index_t Ho, @@ -1645,7 +1645,7 @@ struct ConvBwdDataImplicitGemmOutTransform template __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const { - index_t NStep, HStep, WStep; + index_t NStep{0}, HStep{0}, WStep{0}; // Merge // NStep = M_id / TildeSlice_ NStep = MagicDivision::DoMagicDivision(idx_up[I1], From 190ad2cceeb28b553033393cdbdf0453108ac64a Mon Sep 17 00:00:00 2001 From: Mingtao Gu <145657261+mtgu0705@users.noreply.github.com> Date: Thu, 2 Oct 2025 03:32:55 +0800 Subject: [PATCH 002/172] updated mxfp4 moe gemm2 config (#2330) Co-authored-by: mtgu0705 From a76c7b10281cf46486e6563ffeb3ee9cb4a20348 Mon Sep 17 00:00:00 2001 From: Max Podkorytov 
<4273004+tenpercent@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:00:41 -0700 Subject: [PATCH 003/172] tweak version (#2954) --- pyproject.toml | 4 ++-- python/ck4inductor/__init__.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e05a50af8..e8868ed92d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm"] +requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] @@ -36,4 +36,4 @@ ck4inductor = "python/ck4inductor" "ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp", "src/tensor_operation_instance/gpu/gemm_universal_batched/**/*.hpp", "include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/**/*.hpp"] [tool.setuptools.dynamic] -version = { attr = "setuptools_scm.get_version" } +version = { attr = "ck4inductor.__version__" } diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py index e69de29bb2..ac44aeb777 100644 --- a/python/ck4inductor/__init__.py +++ b/python/ck4inductor/__init__.py @@ -0,0 +1,19 @@ +def __version__(): + import subprocess + + # needs to be manually updated + rocm_version = "7.0.1" + hash_width = 6 + try: + hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[ + :hash_width + ] + except: + hash = "0" * hash_width + try: + change_count = subprocess.check_output( + f"git rev-list rocm-{rocm_version}..HEAD --count", shell=True, text=True + ).strip() + except: + change_count = "0" + return f"{rocm_version}.dev{change_count}+g{hash}" From f2d367262fa278403aa2ed760f169144c9850a81 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 1 Oct 2025 18:22:46 -0400 Subject: [PATCH 004/172] tests: add unit tests for grouped_gemm_multi_d persistent kernels (#2941) * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on 
grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistent kernel * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * fix: incorrect validation method and Dtensor layout in test suite * tests: add unit tests for grouped_gemm_multi_d persistent kernels * parent 5b0af640369b93849335b126d6826b204ccc43a3 author AviralGoelAMD 1758919991 +0000 committer AviralGoelAMD 1759338256 +0000 docs: updated changelog with new feature info fix wp gemm bug when permuteN is false (#2935) * fix wp gemm bug when permuteN is false * code clean --------- Co-authored-by: valarLip <340077269@qq.com> fix copy-paste bug in get_matrix_b; re-enable all tests in multi_abd (#2939) [CK_TILE] FMHA Fix synchronization issue in FWD splitkv combine pipeline (#2934) * Fix validation of rotary embedding with time_kernel_ When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. * Fix synchronization issue in splitkv combine pipeline Different warps can read and then rewrite the same values of lse_acc_lds. Sometimes warps progress at different speeds, so one warp can rewrite values that are still being read by another warp. 
Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_fwd_fp16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --gtest_filter="TestCkTileFmhaFwd/*KV*" [CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836) * Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout Add comments with dropout implementation details Fix performance regression of fwd+dropout * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox; * "scalarize" seed and offset, they may come either from kernel args or from device memory (presumably loaded with vector loads). These changes help the compiler to produce more optimal code and reduce register spilling. Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding Use code based on BlockDropout in BlockDropoutBwd Refactor BlockDropout (fwd) Implement BlockDropout (fwd) for WMMA Originally BlockDropout only supported 32x32 tiles (IsWG32 = true); this version supports 16x16 tiles. If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly to BlockDropoutBwd. Implement BlockDropoutBwd for WMMA Remove MakeRandValLds* functions unused in BlockDropoutBwd Remove unused Run overload from BlockDropoutBwd * Fix regression with philox seed and offset when they exceed 32-bit int __builtin_amdgcn_readfirstlane works with 32-bit values; seed and offset are 64-bit, so they get truncated. * Add F32 MFMA warp gemms * Support f32 in fwd FMHA * Implement transpose_vectors for 4-byte types (float) * Fix unexpected implicit f32->uint32 cast in buffer_store<4> __builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly cast to uint). mbuf_t types in other buffer_store<> are changed for consistency. 
* Support F32 in bwd FMHA hdim = 256 is disabled for now because it uses too much memory on gfx90a * Support Headdim = 48 (divisible by 16) in fwd * Add fp32-specific receipts (800 and 801) * Tune fwd tiles * Tune bwd tiles * Use small tiles only for small seqlen_q * Fix after rebasing * Fix selection of a fallback tile based on bm0 The assumption that the largest bm0 == 128 is not always true for current fp32 tiles. * Remove constraints and adjust filtering for fp32 Custom constraints are no longer needed because now the smallest tile is selected automatically based on seqlen_q. Filters related to qr_async_trload disabled valid fp32 tiles. * Add fp32 tests * Make splitkv and appendkv compile for fp32 only There are no instances yet, but the API must still compile when only fp32 is requested. * Remove unimportant f32 instances * Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS * Replace magic numbers with a constant, improve comments for dropout * Update changelog * Fix condition that dq_acc must be set to zero when mask is used The change was introduced in #2799 * Replace warp_uniform with recently added amd_wave_read_first_lane * Add hdim = 96 and 192 to fwd Use git ls-files to select candidate files for clang format This change ensures that the files being selected for clang format validation are exactly the ones tracked by the git repo we are testing. This protects against a known issue where the repo being tested contained "stray files" from a previous test. [CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769) * Change the return type of run_gemm_combinations in the basic tests * Change the return type of run_gemm_combinations in the universal tests * Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4 * Add universal GEMM test for fp8 x pk_i4 * Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4. 
* Add missing GemmTypeConfig * Add missing GemmTypeConfig * No need for utility in test_ck_tile_elementwise_1d * Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8 * Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant * For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union * Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced * Convert from float to bf16 during compilation rather than using magic values * Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8 * Comment out the basic test for fp16 x pk_i4 as it does not pass * Add missing GemmTypeConfig * Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8 * Add basic and universal GEMM tests for bf8 x pk_i4 * Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now * Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now * Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp * Use explicit macros for enabling and disabling the constexpr lookup based converters * Fix two failing tests * Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant * Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16 * On ROCm 7.0.1 we need an explicit cast from uint16_t to bf16_t Grouped Conv Bwd Data out index calculation optimizations (#2917) * Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporarily disable splitK for gfx12 [CK] Fix example_grouped_conv_bwd_data_xdl_fp16 with ksplit = 2 (#2943) root cause: AK1 and BK1 may differ in the class template, so we need to calculate k0 per block separately when ksplit is not 1. 
[CK][Examples] Extending support for rdna3/4 in following examples: (#2884) * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski --------- Signed-off-by: Michal Kulikowski hot fix check eid range (#2924) * hot fix check eid range * fix clang format --------- Co-authored-by: Illia Silin 
<98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng Weight Preshuffle Block Scale gemm support (#2877) * initial commit * remove extra files * fixing errors * updated ReadMe file for mapping of diff quants with diff configs * addressing review comments * addressing review comments * Resolved merge conflicts * [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled The get_preshuffle_or was not working as expected, which led to incorrect behavior in the quantization preshuffle process. This change replaces it with the more reliable is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied. * initial commit * debugging * working fp8 for init constant * fp8 working with all inits * updated block level code with comments * changing the loop iter * debugging * debugging * debugging * code fix * code clean up * clang formatted * Add comment * code cleanup * clang formatted * merge conflicts fixes * applying the latest int4 changes to the piepline * fixing test code for updated traits * Adding gtest * review comments addressed * addressing review comments * remove c++20 code * added flush cache changes --------- Co-authored-by: Cong Ma Co-authored-by: root increase time limit for AITER tests (#2948) Code style clean-up and documentation The following changes were made: - Clean-up of variable namings - Addition of README - Removal of num_cu and occupancy args; such options are meant for testing purposes and should not be exposed to the user - Removal of CK_TILE_PIPELINE_MEMORY macro and PipelineTypeTraits class since we only support one pipeline at the moment. 
Fix timing issue in CK_TILE GEMM example (#2940) * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistant kernel * fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint Fix timing issue in CK_TILE GEMM example (#2940) * style: clang format * refactor: removed unused file [CK] Add command option instance_index and param_mask to run partial ck test (#2889) * [CK] Add command option instance_index and param_mask to run partial ck test Many CK test are instance test. it will loop all instance in the instance library. It causes test often out-of-time if we run test on simulator/emulator. This PR add option instance_index and param_mask to reduce the workload of instance test instance_index: only run test 1 available instance with specified index. param_mask: filter the embedded parameter with specified mask * fix CI error * fix clang format --------- Co-authored-by: illsilin_amdeng [CK_TILE]enhance elementwise test (#2683) * enhance elementwise * fix ci issues --- CHANGELOG.md | 1 + .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 2 + .../run_grouped_gemm_multi_d_example.inc | 10 +- .../test_grouped_gemm_multi_d.cpp | 15 +- .../test_grouped_gemm_multi_d_util.hpp | 135 +++++++++++++++++- 5 files changed, 149 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 438320d907..be613fb78c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. 
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index d5203a799c..0789452ada 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -76,6 +76,7 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 8; static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; @@ -116,6 +117,7 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 8f275b069b..e1647c037b 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, << std::endl; for(int i = 0; i < group_count; i++) { - Ms.push_back(256 /* + 256 * i */); - Ns.push_back(256 /* + 512 * i */); - Ks.push_back(64 /* + 384 * i */); + 
Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index deea2fc852..c6356a6b2c 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -31,7 +31,8 @@ template + PipelineType Pipeline_val_, + bool Persistent_val_> struct KernelConfig { using ALayoutType = ALayout_; @@ -56,15 +57,19 @@ struct KernelConfig static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; + static constexpr bool Persistent_ = Persistent_val_; static constexpr int BlockPerCu_ = 1; }; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3 - 
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4 + // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 4c13b4a7f7..30a61a081b 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -93,7 +93,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -229,6 
+228,100 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) + { + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + } + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -379,9 
+472,43 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - invoke_grouped_gemm(gemm_descs, - ck_tile::stream_config{nullptr, false, 1}, - gemm_workspace.GetDeviceBuffer()); + if constexpr(Config::Persistent_) + { + std::vector> kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = gemm_descs[0].k_batch > 1; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back( + ck_tile::UniversalGemmKernelArgs<1, 1, DsDataType::size()>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + } + else + { + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + } // Copy results back to host for validation for(int i = 0; i < group_count; i++) From a7da3c68b979bd46c315da09208271d26f5e2900 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:38:07 -0700 Subject: [PATCH 005/172] Add a new gemm pipeline based on ComputeV4 which utilizes async copy API (#2949) * check in pipeline and policy for async load in mi350, need to make sure TileAccessPattern is warp_raked or block_raked solve merge conflicts * fix cmakelists * make it build * fix? 
buffer async fence * relax fences; it appears it only is needed between pairs of ping-pongs * remove fences * remove fences * cleanup and reformat * add steps annotations * comment all pipeline steps / remove unexplainable syncs * clang-format * add comment * cleanup kernel types for test * fix comment * fix hardcoded warp size * faithfully copy block gemm from compute v4 policy to async policy * make async test gfx950 only * fix cmake logic * set separate compile options for async * refine comment in policy * try update hotloop scheduler * cleanup comments * test more K block sizes * unhardcode Ks, sort of * add large odd test case * fix build for quant * add comment to hot loop scheduler and rename enum * reformat * reword the pipeline description * reformat * address review / add static asserts / typo fix * update changelog --- CHANGELOG.md | 1 + include/ck_tile/core/arch/arch.hpp | 16 + include/ck_tile/ops/gemm.hpp | 2 + .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 531 ++++++++++++++++++ ...ine_ag_bg_cr_comp_async_default_policy.hpp | 101 ++++ ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 108 ++-- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 3 - .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 3 - test/ck_tile/gemm/CMakeLists.txt | 6 + .../gemm/test_gemm_pipeline_comp_async.cpp | 17 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 11 +- .../gemm/test_gemm_pipeline_ut_cases.inc | 51 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 15 +- 13 files changed, 803 insertions(+), 62 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index be613fb78c..9aadc3dc54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj 
## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 28ded5439a..3b12cf061b 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -275,4 +275,20 @@ CK_TILE_DEVICE static constexpr auto get_device_arch() return gfx12_t{}; #endif } + +enum LLVMSchedGroupMask : int32_t +{ + NONE = 0, + ALU = 1 << 0, + VALU = 1 << 1, + SALU = 1 << 2, + MFMA = 1 << 3, + VMEM = 1 << 4, + VMEM_READ = 1 << 5, + VMEM_WRITE = 1 << 6, + DS = 1 << 7, + DS_READ = 1 << 8, + DS_WRITE = 1 << 9, + ALL = (DS_WRITE << 1) - 1, +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e..5edde31cd9 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -40,6 +40,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp new file mode 100644 index 0000000000..2c8d008127 --- /dev/null +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -0,0 +1,531 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompAsync +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::Three; + } + else + { + return TailNumber::Two; + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. 
+#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error( + "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported"); +#endif + } +}; + +/** + * @brief Compute optimized pipeline version async; which is based on V4. + * + * This pipeline introduces asynchronous load from global memory to LDS, + * skipping the intermediate loading into pipeline registers. + */ +template +struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync +{ + using Base = BaseGemmPipelineAgBgCrCompAsync; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static 
constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + + constexpr index_t WaveSize = get_warp_size(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { + // TODO: this will likely need to be redesigned after (1) changes 
to reading from + // LDS and (2) re-profiling + ignore = i; + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1 + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6 + }); + __builtin_amdgcn_sched_barrier(0); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + // TODO support multi-ABD + static_assert(1 == std::tuple_size_v); + static_assert(1 == std::tuple_size_v); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + // TODO currently fused elementwise are not supported + ignore = a_element_func; + ignore = b_element_func; + static_assert(std::is_same_v, + element_wise::PassThrough>); + static_assert(std::is_same_v, + element_wise::PassThrough>); + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + // TODO currently only support A matrix row major, B matrix col major; if A matrix is + // col major or B is row major, need to combine with transpose load api + static_assert(!(is_a_col_major || is_b_row_major), + "only support A matrix is row major, B matrix is col major!"); + + static_assert(is_a_col_major + ? 
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// global window & register ///////////////// + // A DRAM tile window(s) for load + auto a_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); + // B DRAM window(s) for load + auto b_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); + + // this pipeline has a pair of LDS buffers per logical tile + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); + auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + + // LDS tile windows for storing, one per LDS buffer + auto a_copy_lds_window0 = make_tile_window( + a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + + auto a_copy_lds_window1 = make_tile_window( + a_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window0 = make_tile_window( + b_lds_block0, 
make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window1 = make_tile_window( + b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + // initialize DRAM window steps, used to advance the DRAM windows + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + // read A(0), B(0) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + // initialize block gemm + auto block_gemm = BlockGemm(); + + // initialize C block tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + clear_tile(c_block_tile); + + // read A(1), B(1) from DRAM to LDS window(1) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + // register tiles; double buffering -> a register tile corresponds to a LDS tile window + ALdsTile a_block_tile0; + ALdsTile a_block_tile1; + + BLdsTile b_block_tile0; + BLdsTile b_block_tile1; + 
+ // LDS tile windows for reading; + // they share the data pointer with the LDS windows for storing + // but also associate with a distribution to produce a register tile when reading + auto a_lds_ld_window0 = + make_tile_window(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto a_lds_ld_window1 = + make_tile_window(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto b_lds_ld_window0 = + make_tile_window(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + auto b_lds_ld_window1 = + make_tile_window(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + + static_assert(!(is_tile_window_linear_v) && + !(is_tile_window_linear_v) && + !(is_tile_window_linear_v) && + !(is_tile_window_linear_v), + "LDS windows must not be linear"); + + // write to LDS window(0) must complete before the local prefetch + block_sync_lds_direct_load(); + // read A(0), B(0) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // LDS window(0) contents are overwritten below by global prefetch, need to sync + block_sync_lds(); + // read A(2), B(2) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + if(HasHotLoop) + { + // we have had 3 global prefetches so far, indexed (0, 1, 2). 
+ index_t i_global_read = amd_wave_read_first_lane(3); + // alternate ping: (read to register tile(1), use register tile(0) as gemm input) + // pong: (read to register tile(0), use register tile(1) as gemm input) + do + { + // ping + { + // read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // LDS window(1) contents are overwritten by global prefetch, need to sync + block_sync_lds(); + // read A(i), B(i) from DRAM to LDS window(1) + // and advance the DRAM windows + Base::GlobalPrefetchAsync(a_copy_lds_window1, + a_tile_windows[number<0>{}], + a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_copy_lds_window1, + b_tile_windows[number<0>{}], + b_dram_tile_window_step); + // C(i-3) = A(i-3) @ B(i-3) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + HotLoopScheduler(); + } + // pong + { + // write to LDS window(0) must complete before the local prefetch + block_sync_lds_direct_load(); + // read A(i), B(i) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // LDS window(0) contents are overwritten by global prefetch, need to sync + block_sync_lds(); + // read A(i+1), B(i+1) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync(a_copy_lds_window0, + a_tile_windows[number<0>{}], + a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_copy_lds_window0, + b_tile_windows[number<0>{}], + b_dram_tile_window_step); + // C(i-2) = A(i-2) @ B(i-2) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + HotLoopScheduler(); + } + i_global_read += 2; + } while(i_global_read < num_loop); + } + + // 3 block gemms remaining + if constexpr(TailNum == TailNumber::Three) + { + { + // read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, 
a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + { + // read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + } + { + // C(num_loop) = A(num_loop) @ B(num_loop) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + } + else + // 2 block gemms remaining + { + { + // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + { + // C(num_loop) = A(num_loop) @ B(num_loop) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + } + } + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const 
BDataType& b) { return b; }, + num_loop, + p_smem_0, + p_smem_1); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp new file mode 100644 index 0000000000..23104375d6 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAgBgCrCompAsync +// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor +// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy +struct GemmPipelineAgBgCrCompAsyncDefaultPolicy + : public UniversalGemmBasePolicy +{ + static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; + static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + return transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, 
sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackB(); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + return transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? 
WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c8f874acd6..4030783ecc 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -9,6 +9,26 @@ namespace ck_tile { +template +struct has_a_tile_access_pattern : std::false_type +{ +}; + +template +struct has_a_tile_access_pattern> : std::true_type +{ +}; + +template +struct has_b_tile_access_pattern : std::false_type +{ +}; + +template +struct has_b_tile_access_pattern> : std::true_type +{ +}; + template struct UniversalGemmBasePolicy { @@ -30,8 +50,25 @@ struct UniversalGemmBasePolicy static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; - static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; - static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + // Default tile access patterns + static constexpr auto DefaultATileAccessPattern = tile_distribution_pattern::thread_raked; + static constexpr auto DefaultBTileAccessPattern = tile_distribution_pattern::thread_raked; + + static constexpr auto getATileAccessPattern() + { + if constexpr(has_a_tile_access_pattern::value) + return Derived::ATileAccessPattern; + else + return DefaultATileAccessPattern; + } + + static constexpr auto getBTileAccessPattern() + { + if constexpr(has_b_tile_access_pattern::value) + return Derived::BTileAccessPattern; + else + return DefaultBTileAccessPattern; + } template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -168,11 +205,12 @@ struct 
UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -500,23 +538,25 @@ struct UniversalGemmBasePolicy // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: KPerBlock X MPerBlock else { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -536,23 +576,25 @@ struct UniversalGemmBasePolicy // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: NPerBlock X KPerBlock else { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -573,7 +615,7 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, - ATileAccessPattern, + getATileAccessPattern(), NumWaveGroups>; return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } @@ -594,7 +636,7 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, - BTileAccessPattern, + getBTileAccessPattern(), NumWaveGroups>; return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } diff --git 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 926f63b5a9..9e40e1f08c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -15,9 +15,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using Base::I1; using Base::I2; - using Base::ATileAccessPattern; - using Base::BTileAccessPattern; - template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index eea8038edf..f9278bf985 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -15,9 +15,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using Base::I1; using Base::I2; - using Base::ATileAccessPattern; - using Base::BTileAccessPattern; - template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() { diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 44e2433060..1ca7f4fc7d 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -mllvm -enable-noalias-to-md-conversion=0 ) +set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) @@ -60,6 +61,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") target_compile_options(test_ck_tile_gemm_pipeline_persistent 
PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() + if(GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async test_gemm_pipeline_comp_async.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_comp_async PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS}) + endif() + if(GPU_TARGETS MATCHES "gfx11|gfx12") # On Radeon devices, build the WMMA version instead add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp b/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp new file mode 100644 index 0000000000..c41d40937d --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompAsync + : public TestCkTileGemmPipeline> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync, KernelTypesCompAsync); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index a55cd100c1..243a823653 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -26,9 +26,10 @@ using Intrawave = ck_tile::integral_constant; -using Mem = ck_tile::integral_constant; -using CompV3 = ck_tile::integral_constant; -using CompV4 = ck_tile::integral_constant; +using Mem = ck_tile::integral_constant; +using CompV3 = ck_tile::integral_constant; +using CompV4 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; using Persistent = std::true_type; using NonPersistent = std::false_type; @@ -129,6 +130,10 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, 
I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompAsync = ::testing::Types< + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> +>; + using KernelTypesCompV4Wmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index c824d034a9..f793f81cc9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -10,18 +10,25 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; - constexpr int K = 320; + std::vector Ks; + for (auto K_count: {2, 3, 4, 10, 11}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } for(int M : Ms) { - if constexpr(std::is_same_v) + for(int K : Ks) { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); - } - else - { - this->Run(M, N, K); + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } } } } @@ -30,7 +37,12 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; - constexpr int K = 320; + + std::vector Ks; + for (auto K_count: {2, 3, 4, 10, 11}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } constexpr int VecLoadSize = (std::is_same_v || std::is_same_v || std::is_same_v) @@ -39,22 +51,25 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) for(int M : Ms) { - if constexpr(std::is_same_v) + for (int K: Ks) { - if(M % VecLoadSize == 0) + if constexpr(std::is_same_v) { - this->Run(M, N, K); + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { - EXPECT_THROW((this->Run(M, 
N, K)), std::runtime_error); + this->Run(M, N, K); } } - else - { - this->Run(M, N, K); - } } } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index af4f8d3d38..01bc3d7522 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -37,7 +37,8 @@ enum struct GemmPipelineType { Mem, CompV3, - CompV4 + CompV4, + CompAsync }; template @@ -70,6 +71,15 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -110,7 +120,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? 
true : false; + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || + PipelineType == GemmPipelineType::CompAsync); // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; From a4ab33f539ac9d7209c6274958dc0285eacf3e78 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Thu, 2 Oct 2025 20:09:49 +0600 Subject: [PATCH 006/172] Fix building test_fmha_bwd_fp32 on SLES15 (#2962) --- test/ck_tile/fmha/test_fmha_bwd_fp32.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp index d409d0dd30..09010d4b22 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp @@ -15,6 +15,6 @@ const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tupl const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +const std::string init_method = "uf"; #include "test_fmha_bwd.inc" From 6fc28ab4934d3668bf4ec96db1e082cf26b11384 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 2 Oct 2025 12:13:51 -0600 Subject: [PATCH 007/172] [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle (#2897) * [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle When TransposeC and QuantPreshuffle are both true, Aquant generates correct result. 
* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle - Add unit tests * Fix bug in is_quantpreshuffle_enabled * clang format --------- Co-authored-by: ThomasNing --- .../block_universal_gemm_as_aquant_bs_cr.hpp | 41 ++++++++++++--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 2 +- .../pipeline/tile_gemm_quant_traits.hpp | 3 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 4 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 1 + .../test_gemm_quant_fixtures.hpp | 52 +++++++++++++++++-- .../test_gemm_quant_typed.cpp | 21 +++++++- 7 files changed, 109 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index d4bece1a83..cb20bdbd50 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { if constexpr(Traits::TransposeC) // transposed C { - static_assert(false, - "It is not supported yet to enable both Preshuffle " - "and TransposeC."); - // TODO: - // A new tile distribution is needed for the Preshuffle and - // Transpose combination. For instance, with mnk at 16x16x32, lanes - // 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ. 
+ constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; + auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) * + Traits::AQPerBlock + + kQScale; + + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * + scale_reg_f); + }); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index a0b6fc5821..bba2bc8400 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled }; template -struct is_quantpreshuffle_enabled +struct is_quantpreshuffle_enabled> { static constexpr bool value = T::PreshuffleQuant; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index 52a326a897..c4429b76f9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -39,6 +39,7 @@ template struct TileGemmQuantTraits @@ -62,7 +63,7 @@ struct TileGemmQuantTraits using AsLayout = ALayout_; using BsLayout = BLayout_; - static constexpr bool TransposeC = false; + 
static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; static constexpr bool UsePersistentKernel = UsePersistentKernel_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 93a13ba5af..3a49e69c37 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -7,7 +7,9 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") # Typed Test Suite for GEMM Quantization - add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp) + add_gtest_executable(test_tile_gemm_quant_typed + test_gemm_quant_typed.cpp + ) target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 355e9fce32..80167a1d21 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -87,6 +87,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test QuantType, ALayout, BLayout, + GemmConfig::TransposeC, DoubleSmemBuffer>; // Let the derived class create the appropriate pipeline and epilogue diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 98f88f4d53..21eabd6041 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -41,6 +41,22 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 32; }; +struct GemmConfigPreshuffleQuant : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; +}; + +struct 
GemmConfigTransposeC : public GemmConfigBase +{ + static constexpr bool TransposeC = true; +}; + +struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; + static constexpr bool TransposeC = true; +}; + struct GemmConfigPreshuffleB { static constexpr bool kPadM = false; @@ -100,6 +116,24 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase + auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) + { + if(t->get_lengths().size() != 2) + { + throw std::runtime_error("Host tensor is not rank 2 tensor."); + } + int m_ = t->get_lengths()[0]; + int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) + { + throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); + } + ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); + } + // AQuant-specific data generation void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) { @@ -150,7 +184,17 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase aq_shuffle_host = + shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } b_k_n_dev_buf.ToDevice(b_k_n.data()); // Create args for kernel execution @@ -245,7 +289,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase; // Type combinations for each quantization type // clang-format off using AQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // PreshuffleQuant = false && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = false + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && 
TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on From cadafde722f838d1fc0b08130cd4fca168acc29c Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 2 Oct 2025 11:15:24 -0700 Subject: [PATCH 008/172] add the check of granularity for atomic add (#2959) --- .../impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 4 ++++ test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 7e9020d796..02639dbf3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -682,6 +682,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK{}] <= 1 && (arg.KBatch > 1)) + { + return false; + } else { if constexpr(NXdlPerWave32 > 0) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index f793f81cc9..66ef05b0ba 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -11,7 +11,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; std::vector Ks; - for (auto K_count: {2, 3, 4, 10, 11}) + for(auto K_count : {2, 3, 4, 10, 11}) { Ks.push_back(K_count * TestFixture::K_Tile); } @@ -36,10 +36,10 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 1024; + constexpr int N = 1024; std::vector Ks; - for (auto K_count: {2, 3, 4, 10, 11}) + for(auto K_count : {2, 3, 4, 10, 11}) { Ks.push_back(K_count * TestFixture::K_Tile); } @@ -51,7 +51,7 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) for(int M : 
Ms) { - for (int K: Ks) + for(int K : Ks) { if constexpr(std::is_same_v) From 0a30c3063068dcefea2291309fbe269812d06956 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 2 Oct 2025 11:54:45 -0700 Subject: [PATCH 009/172] fix build on legacy systems without cpp20 compiler (#2958) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index 8f24c9bfe1..3f15e8c6aa 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -31,7 +31,7 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) if(ck_tile::get_device_name() != "gfx950") { - gemmParams.emplace_back(256, 256, 128, 2); + gemmParams.push_back({256, 256, 128, 2}); } for(auto& params : gemmParams) From 4c98535456c468cbd36d39de4a92406fa3a012b6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 3 Oct 2025 07:08:49 -0700 Subject: [PATCH 010/172] fix compilation errors on RHEL8 and SLES15 (#2967) --- .../gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index df51a2aa27..4c54ec85c1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -196,7 +196,7 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1>; + using DLayout = remove_cvref_t>; if constexpr(is_same::value) return Number{}; else @@ -253,7 +253,7 @@ struct 
DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1{}([&](auto i) { DsLengths[i] = out_lengths; - using DLayout = ::std::__remove_cvref_t>; + using DLayout = remove_cvref_t>; if constexpr(is_same::value) { DsStrides[i] = {arg.StrideDs[i], 1}; From b4a4aa2b64a7a94ab04126545a3dc4f6d3eba847 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Fri, 3 Oct 2025 09:46:13 -0700 Subject: [PATCH 011/172] [CK Tile] CShuffle Tile Permute N all warp compatible (#2966) * solve the hard_code issue of kM2 * clang format --- .../ops/epilogue/cshuffle_epilogue.hpp | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e0a39a5aea..5918ec806b 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -433,8 +433,13 @@ struct CShuffleEpilogue const ScaleM& scale_m = {}, const ScaleN& scale_n = {}) { + static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size(); + + static_assert(MPerXdl % RowsPerLane == 0, + "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count."); + constexpr int kM0 = MWave; - constexpr int kM2 = 4; + constexpr int kM2 = RowsPerLane; constexpr int kM1 = MPerXdl / kM2; constexpr int kN0 = NWave; @@ -515,32 +520,25 @@ struct CShuffleEpilogue // Pack 4 “rows per lane” as you already do static_for<0, NRepeat, 1>{}([&](auto n_idx) { // source indices in shuffle_acc: (n_idx * product(Y) + row) - const index_t base = n_idx * c_warp_y_lengths.product(); + const index_t plane = c_warp_y_lengths.product(); // local lambda to fuse scale (if present) and convert - auto emit = [&](index_t out_idx, index_t src_row) { - AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row]; - + static_for<0, kM2, 1>{}([&](auto m_lane) { + const int src = n_idx * plane + m_lane; // source row in this N-plane + const int dst = n_idx + m_lane * NRepeat; // 
permuted N layout in output + AccDataType v = shuffle_acc.get_thread_buffer()[src]; if constexpr(has_scalar_scales) { v = static_cast(v * scale_m * scale_n); } else if constexpr(has_scales && !has_scalar_scales) { - // same linear index mapping on the permuted distribution - const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); - const auto s_n = static_cast(sn_tile.get_thread_buffer()[out_idx]); - v = static_cast(v * s_m * s_n); + const auto sm = static_cast(sm_tile.get_thread_buffer()[dst]); + const auto sn = static_cast(sn_tile.get_thread_buffer()[dst]); + v = static_cast(v * sm * sn); } - - c_out_tensor.get_thread_buffer()[out_idx] = type_convert(v); - }; - - // Your current packing pattern (rows 0..3, spaced by NRepeat) - emit(n_idx + 0 * NRepeat, 0); - emit(n_idx + 1 * NRepeat, 1); - emit(n_idx + 2 * NRepeat, 2); - emit(n_idx + 3 * NRepeat, 3); + c_out_tensor.get_thread_buffer()[dst] = type_convert(v); + }); }); // store/update From 58983a323287d41dff8b37c5318942d7159559dc Mon Sep 17 00:00:00 2001 From: Geo Min Date: Fri, 3 Oct 2025 12:50:16 -0700 Subject: [PATCH 012/172] [TheRock CI] Bumping hash for TheRock (#2972) * Adding new hash for TheRock * Removing package --- .github/workflows/therock-ci-linux.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 695fb1d913..25b345880b 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -41,7 +41,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200 + ref: 3f62012a748df3a3099c51fa95d104db643a4588 # 10-03-2025 commit path: "TheRock" - name: Runner Health Settings @@ -54,6 +54,7 @@ jobs: - name: Patch rocm-libraries run: | + rm ./TheRock/patches/amd-mainline/rocm-libraries/0009-Use-workgroupMappingDim-in-rocroller_host.patch git config --global 
--add safe.directory '*' git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch From 96efe2f4855d643c2f88ff8d67eab6f21461fce1 Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Mon, 6 Oct 2025 12:00:58 +0200 Subject: [PATCH 013/172] ck tile engine integrate with gemm unit tests (#2601) * first try to understand how tile engine works * 1st implemented unit tests * manage different types for unit tests * manage using different config files to have different unit tests * manage different layouts * making instances and running them by unit test * Add reference calculation * manage different input dimension combination * add splitk to unit tests. clean code. * remove unused files * clean and test with a simple json file --- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/gemm_tile_engine/CMakeLists.txt | 237 ++++++++++++++++++ test/ck_tile/gemm_tile_engine/README.md | 27 ++ .../configs/simple_test_config.json | 89 +++++++ .../gemm_tile_engine/test_gemm_simple.cpp | 223 ++++++++++++++++ 5 files changed, 577 insertions(+) create mode 100644 test/ck_tile/gemm_tile_engine/CMakeLists.txt create mode 100644 test/ck_tile/gemm_tile_engine/README.md create mode 100644 test/ck_tile/gemm_tile_engine/configs/simple_test_config.json create mode 100644 test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index b92888b1f1..04be25f30a 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -30,3 +30,4 @@ add_subdirectory(reduce) add_subdirectory(epilogue) add_subdirectory(atomic_add_op) add_subdirectory(fmha) +add_subdirectory(gemm_tile_engine) diff --git a/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_tile_engine/CMakeLists.txt new file mode 100644 index 0000000000..8a3e9e1990 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/CMakeLists.txt @@ -0,0 +1,237 @@ +# 
============================================================================ +# GEMM Tile Engine Unit Tests +# +# This CMake file creates unit tests for tile_engine generated GEMM kernels. +# It follows the exact same build patterns as tile_engine for consistency +# and reliability. Each kernel configuration gets its own test executable. +# ============================================================================ + +# Locate tile_engine GEMM scripts directory +set(TILE_ENGINE_GEMM_DIR "${PROJECT_SOURCE_DIR}/tile_engine/ops/gemm") + +if(NOT EXISTS ${TILE_ENGINE_GEMM_DIR}) + message(WARNING "Tile engine directory not found: ${TILE_ENGINE_GEMM_DIR}") + return() +endif() + +# ============================================================================ +# create_individual_gemm_test_target +# +# Creates a single test executable for a specific kernel configuration. +# Mirrors tile_engine's create_individual_gemm_target function for consistency. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) 
+# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# trait - Kernel trait combination string +# tile_config - Tile configuration parameters +# config_json - Full path to JSON configuration file +# ============================================================================ +function(create_individual_gemm_test_target datatype layout config_name trait tile_config config_json) + set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Generated header path for this specific kernel configuration + set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Generate kernel header using tile_engine's Python script + add_custom_command( + OUTPUT ${test_header} + COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json} + COMMENT "Generating test header ${test_header}" + VERBATIM + ) + + # Create GTest executable for this kernel configuration + add_gtest_executable(${target_name} + ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp + ) + + # Ensure header is generated before compilation + set(header_target "${target_name}_header") + add_custom_target(${header_target} DEPENDS ${test_header}) + add_dependencies(${target_name} ${header_target}) + + # Configure GPU architectures for HIP compilation + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) + + # Define preprocessor macros for generated header location + target_compile_definitions(${target_name} 
PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${test_header}" + ) + + # Include directories for headers and dependencies + target_include_directories(${target_name} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_BINARY_DIR}/include + ${PROJECT_SOURCE_DIR} # Root directory for tile_engine access + ${GTEST_INCLUDE_DIRS} + ) + + # Compiler options matching tile_engine requirements + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template # Suppress template warnings + -Wno-float-equal # Allow floating point comparisons + --offload-compress # Enable GPU code compression + -include ${test_header} # Auto-include generated header + ) + + message(STATUS " Created test target: ${target_name}") +endfunction() + +# ============================================================================ +# build_gemm_test_targets +# +# Builds all test targets for a specific datatype/layout/config combination. +# Uses tile_engine's two-step process: list kernels, then generate tests. +# +# Parameters: +# datatype - Data type (fp16, bf16, fp32, etc.) 
+# layout - Matrix layout (rcr, rrr, ccr, crr) +# config_name - Configuration file name without .json extension +# ============================================================================ +function(build_gemm_test_targets datatype layout config_name) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") + + # Locate and validate configuration file + set(config_filename "${config_name}.json") + set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") + message(STATUS " Using test config: ${config_filename}") + + if(NOT EXISTS ${json_blob}) + message(WARNING "Test config file not found: ${json_blob}") + return() + endif() + + # Prepare build directory for this configuration + file(MAKE_DIRECTORY ${working_path}) + + # STEP 1: Discovery phase - list all valid kernel configurations + message(STATUS " Listing kernel configurations for ${datatype}_${layout}...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}") + return() + endif() + + # Validate kernel discovery results + if(EXISTS ${working_path}/gemm_kernel_count.txt) + file(READ ${working_path}/gemm_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}") + else() + message(WARNING "Kernel count file not found for ${datatype}_${layout}") + return() + endif() + + # STEP 2: Generation phase - create test targets for each discovered kernel + if(EXISTS ${working_path}/gemm_kernel_list.txt) + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) 
+ set(test_count 0) + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate test target for this kernel configuration + create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") + math(EXPR test_count "${test_count} + 1") + endif() + endforeach() + message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") + else() + message(WARNING "Kernel list file not found for ${datatype}_${layout}") + endif() +endfunction() + +# ============================================================================ +# MAIN EXECUTION - Test Target Generation +# ============================================================================ + +message(STATUS "=== Starting GEMM Tile Engine Test Configuration ===") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# GPU architecture filtering - only build tests for supported architectures +set(GEMM_TEST_GPU_TARGETS "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_TEST_GPU_TARGETS ${target}) + message(STATUS " Adding GPU target for tests: ${target}") + endif() +endforeach() + +# Early exit if no compatible GPU architectures are available +if(NOT GEMM_TEST_GPU_TARGETS) + message(WARNING "Skipping GEMM Tile Engine tests: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + return() +endif() + +message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") + +# ============================================================================ +# Test Configuration Matrix +# 
============================================================================ + +# Available test configurations (minimal set for fast CI/testing) +set(TEST_CONFIGS + "simple_test_config" + # "medium_tiles_config" # Uncomment for broader testing +) + +# Data types for testing (core precision types) +set(TEST_DATATYPES "fp16" "bf16") +# Extended data type options: +# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8") + +# Matrix layouts for testing (row-column-row is most common) +set(TEST_LAYOUTS "rcr") +# Extended layout options: +# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr") + +# ============================================================================ +# Test Target Generation Loop +# ============================================================================ + +foreach(datatype IN LISTS TEST_DATATYPES) + foreach(layout IN LISTS TEST_LAYOUTS) + foreach(config IN LISTS TEST_CONFIGS) + set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json") + if(EXISTS ${CONFIG_FILE}) + message(STATUS "Building tests for ${datatype}_${layout}_${config}") + build_gemm_test_targets("${datatype}" "${layout}" "${config}") + else() + message(WARNING "Config file not found: ${CONFIG_FILE}") + endif() + endforeach() + endforeach() +endforeach() + +message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations") diff --git a/test/ck_tile/gemm_tile_engine/README.md b/test/ck_tile/gemm_tile_engine/README.md new file mode 100644 index 0000000000..d99b4115d3 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/README.md @@ -0,0 +1,27 @@ +# GEMM Tile Engine Unit Tests + +## How It Works + +This unit test system integrates **tile_engine's kernel generation** into automated testing: + +1. **Uses tile_engine scripts directly**: Same Python scripts that generate tile_engine kernels +2. **JSON-based configuration**: Define test parameters in JSON files (like tile_engine) +3. 
**Build-time generation**: CMake calls tile_engine scripts to generate kernel headers +4. **Individual test executables**: Each kernel configuration becomes a separate test +5. **Tile_engine verification**: Uses exact same error thresholds and validation as tile_engine + +## Tile Engine Integration + +``` +JSON Config → tile_engine Python scripts → Generated Headers → Test Executables +``` + +- **`--list_kernels`**: Get available kernel configurations from JSON +- **`--gen_single`**: Generate individual kernel header for each configuration +- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations +- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching + + + + +The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. diff --git a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json new file mode 100644 index 0000000000..c80210b963 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json @@ -0,0 +1,89 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "values": [ + 128 + ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 64 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + 
false + ] + } + } +} diff --git a/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp new file mode 100644 index 0000000000..439dd4f39b --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +// Unit tests for tile_engine generated GEMM kernels +// Tests kernel correctness using tile_engine's verification methodology + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "tile_engine/ops/gemm/gemm_common.hpp" + +// The kernel header is included via compile command line with -include flag +// It defines SelectedKernel struct, KERNEL_NAME, and tensor data types + +// Adaptive error threshold calculation matching tile_engine's implementation +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations (from tile_engine) +template +bool compare_results(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + 
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} + +// Test parameter structure for matrix dimensions and split_k values +struct GemmTestParams +{ + int m, n, k, split_k; +}; + +class GemmTileEngineTest : public ::testing::TestWithParam +{ + protected: + void SetUp() override + { + auto params = GetParam(); + m_ = params.m; + n_ = params.n; + k_ = params.k; + split_k_ = params.split_k; + + // Calculate strides (following tile_engine pattern) + if constexpr(std::is_same_v) + { + stride_a_ = k_; + } + else + { + stride_a_ = m_; + } + + if constexpr(std::is_same_v) + { + stride_b_ = n_; + } + else + { + stride_b_ = k_; + } + + if constexpr(std::is_same_v) + { + stride_c_ = n_; + } + else + { + stride_c_ = m_; + } + } + + // Test dimensions + int m_, n_, k_, split_k_; + int stride_a_, stride_b_, stride_c_; +}; + +TEST_P(GemmTileEngineTest, BasicFunctionality) +{ + // Get tensor layouts from generated kernel + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + // Use split_k from test parameters + int split_k = split_k_; + int stride_a_calc = ck_tile::get_default_stride(m_, k_, 0, is_row_major(layout_a)); + int stride_b_calc = ck_tile::get_default_stride(k_, n_, 0, is_row_major(layout_b)); + int stride_c_calc = ck_tile::get_default_stride(m_, n_, 0, is_row_major(layout_c)); + + // Create host tensors with 
proper descriptors + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(m_, k_, stride_a_calc, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(k_, n_, stride_b_calc, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(m_, n_, stride_c_calc, is_row_major(layout_c))); + + // Initialize input tensors with uniform random distribution [-1.0, 1.0] (matches tile_engine) + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + + // Allocate GPU device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + // Copy data to device and zero output buffer + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + // Calculate reference result on host for verification + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + + // Create GEMM kernel arguments + ck_tile::GemmHostArgs gemm_args(a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + split_k, + m_, + n_, + k_, + stride_a_calc, + stride_b_calc, + stride_c_calc); + + // Configure kernel execution for maximum speed (no timing, no debug output) + ck_tile::stream_config stream_config{nullptr, // stream + false, // time_kernel (disable timing for speed) + 0, // log_level (disable debug output) + 0, // n_warmup + 1, // n_repeat + false, // is_gpu_timer (unused when time_kernel=false) + false, // flush_cache + 1}; // rotating_count + + // Launch the generated kernel (no timing overhead for fastest execution) + try + { + 
SelectedKernel::launch(gemm_args, stream_config); + // Kernel launched successfully if no exception thrown + } + catch(const std::exception& e) + { + FAIL() << "Kernel launch failed: " << e.what(); + } + + // Copy result back from device + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + // Verify results using tile_engine's adaptive error thresholds + bool verification_passed = compare_results( + KERNEL_NAME, k_, split_k, c_m_n_dev_result, c_m_n_host_result); + + EXPECT_TRUE(verification_passed) << "GEMM result verification failed"; +} + +TEST_P(GemmTileEngineTest, KernelInfo) +{ + // Simple test to verify kernel information is available + EXPECT_TRUE(strlen(KERNEL_NAME) > 0) << "Kernel name should not be empty"; + + std::cout << "Testing kernel: " << KERNEL_NAME << std::endl; + std::cout << "Problem size: " << m_ << "x" << n_ << "x" << k_ << " with split_k=" << split_k_ + << std::endl; +} + +// Define test parameters for GEMM verification +INSTANTIATE_TEST_SUITE_P(GemmVerification, + GemmTileEngineTest, + ::testing::Values(GemmTestParams{256, 256, 128, 1}, + GemmTestParams{256, 256, 1024, 1}, + GemmTestParams{256, 512, 512, 1}, + GemmTestParams{512, 256, 512, 1}), + [](const ::testing::TestParamInfo& param_info) { + return std::to_string(param_info.param.m) + "x" + + std::to_string(param_info.param.n) + "x" + + std::to_string(param_info.param.k) + "_splitk" + + std::to_string(param_info.param.split_k); + }); From d4761d7807da0a9205af0e2684e5a1a74e0052ad Mon Sep 17 00:00:00 2001 From: Geo Min Date: Mon, 6 Oct 2025 08:38:38 -0700 Subject: [PATCH 014/172] Fixing hash (#2973) --- .github/workflows/therock-ci-linux.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 25b345880b..ce8ab6120a 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -41,7 +41,7 @@ jobs: uses: 
actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: 3f62012a748df3a3099c51fa95d104db643a4588 # 10-03-2025 commit + ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit path: "TheRock" - name: Runner Health Settings @@ -54,7 +54,6 @@ jobs: - name: Patch rocm-libraries run: | - rm ./TheRock/patches/amd-mainline/rocm-libraries/0009-Use-workgroupMappingDim-in-rocroller_host.patch git config --global --add safe.directory '*' git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch From 19415d0b6f7766e0523baad10ef0a53232b1defd Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 6 Oct 2025 15:43:23 -0400 Subject: [PATCH 015/172] fix: nil performance results for gemm examples (#2950) --- .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 7 +- .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 27 ++- .../03_gemm/universal_gemm_invoker.hpp | 7 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 181 +++++++++--------- .../grouped_gemm_preshuffle.cpp | 150 +++++++-------- .../17_grouped_gemm/quant_grouped_gemm.cpp | 26 ++- example/ck_tile/18_flatmm/flatmm_basic.cpp | 36 ++-- 7 files changed, 208 insertions(+), 226 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index 8c7589dabb..9ece1638b5 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -252,15 +252,14 @@ struct SplitKTwoStageInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - 
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index f200332588..dd13ed7bba 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { // For workspace mode, always use SET operation since each K-split writes to separate memory - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } /** diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index e0d97a50db..d0fd69b1e2 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -185,15 +185,14 @@ struct 
UniversalInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 606d98d9e2..f5335c3ec0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -70,99 +70,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - 
if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto 
tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template ( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; if(!splitk) { - Run(ck_tile::integral_constant{}); + return ave_time = Run(ck_tile::integral_constant{}); } else { - Run(ck_tile::integral_constant{}); + return ave_time = + Run(ck_tile::integral_constant{}); } - - return ave_time; } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4ce55e8e72..b9d6a4a1bc 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -76,99 +76,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool 
has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << 
", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 409bb173a1..64c9dda64a 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -109,23 +109,19 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + 
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; - Run(ck_tile::integral_constant{}); - - return ave_time; + return ave_time = Run(ck_tile::integral_constant{}); } #include "quant_run_grouped_gemm_example.inc" diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 280da8d333..3273fac674 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -167,38 +167,38 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template