From f2d367262fa278403aa2ed760f169144c9850a81 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 1 Oct 2025 18:22:46 -0400 Subject: [PATCH 1/6] tests: add unit tests for grouped_gemm_multi_d persistent kernels (#2941) * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistant kernel * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * style: clang format * fix: incorrect validation method and Dtensor layout in test suite * tests: add unit tests for grouped_gemm_multi_d persistent kernels * parent 5b0af640369b93849335b126d6826b204ccc43a3 author AviralGoelAMD 1758919991 +0000 committer AviralGoelAMD 1759338256 +0000 docs: updated changelog with new feature info fix wp gemm bug when permuteN is false (#2935) * fix wp gemm bug when permuteN is false * code clean --------- Co-authored-by: valarLip <340077269@qq.com> fix copy-paste bug in get_matrix_b; re-enable all tests in multi_abd (#2939) [CK_TILE] FMHA Fix synchronization issue in FWD splitkv combine pipeline (#2934) * Fix validation of rotary embedding with time_kernel_ When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. * Fix synchronization issue in splitkv combine pipeline Different warps can read and then rewrite the same values of lse_acc_lds. Sometimes warps progress at different speeds, one warp can rewrite values that are still being read by another warp. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_fwd_fp16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --gtest_filter="TestCkTileFmhaFwd/*KV*" [CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836) * Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout Add comments with dropout implementation details Fix performance regression of fwd+dropout * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox; * "scalarize" seed and offset, they may come either from kernel args or from device memory (presumably loaded with vector loads). These changes help the compiler to procude more optimal code and reduce register spilling. Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding Use code based on BlockDropout in BlockDropoutBwd Refactor BlockDropout (fwd) Implement BlockDropout (fwd) for WMMA Originally BlockDropout only supported 32x32 tiles (IsWG32 = true), this version supports 16x16 tiles. If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly to BlockDropoutBwd. Implement BlockDropoutBwd for WMMA Remove MakeRandValLds* functions unused in BlockDropoutBwd Remove unused Run overload from BlockDropoutBwd * Fix regression with philox seed and offset when they exceed 32-bit int __builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset are 64-bit so they get truncated. * Add F32 MFMA warp gemms * Support f32 in fwd FMHA * Implement transpose_vectors for 4-byte types (float) * Fix unexpected implicit f32->uint32 cast in buffer_store<4> __builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly casted to uint). mbuf_t types in other buffer_store<> are changed for consistency. * Support F32 in bwd FMHA hdim = 256 is disabled for now because it uses too much memory on gfx90a * Support Headdim = 48 (divisible by 16) in fwd * Add fp32-specific receipts (800 and 801) * Tune fwd tiles * Tune bwd tiles * Use small tiles only for small seqlen_q * Fix after rebasing * Fix selection of a fallback tile based on bm0 The assumption that the largest bm0 == 128 is not always true for current fp32 tiles. * Remove constraints and adjust filtering for fp32 Custom constraints are no longer needed because now the smallest tile is selected automtically based on seqlen_q. Filters related to qr_async_trload disabled valid fp32 tiles. * Add fp32 tests * Make splitkv and appendkv compile for fp32 only There are no instances yet, but API still must compile when only fp32 is requested. * Remove unimportant f32 instances * Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS * Replace magic numbers with a constant, improve comments for dropout * Update changelog * Fix condition that dq_acc must be set to zero when mask is used The change was introduced in #2799 * Replace warp_uniform with recently added amd_wave_read_first_lane * Add hdim = 96 and 192 to fwd Use git ls-files to select candidate files for clang format This change ensures that the files being selected for clang format validation are exactly the ones tracked by the git repo we are testing. This protects against an known issue where the repo being tested contained "stray files" from a previous test. [CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769) * Change the return type of run_gemm_combinations in the basic tests * Change the return type of run_gemm_combinations in the universal tests * Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4 * Add universal GEMM test for fp8 x pk_i4 * Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4. * Add missing GemmTypeConfig * Add missing GemmTypeConfig * No need for utility in test_ck_tile_elementwise_1d * Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8 * Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant * For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union * Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced * Convert from float to bf16 during compilation rather than using magic values * Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8 * Comment out the basic test for fp16 x pk_i4 as it does not pass * Add missing GemmTypeConfig * Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8 * Add basic and universal GEMM tests for bf8 x pk_i4 * Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now * Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now * Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp * Use explicit macros for enabling and disabling the the constexpr lookup based converters * Fix two failing tests * Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant * Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16 * On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t Grouped Conv Bwd Data out index calculation optimizations (#2917) * Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporary disable splitK for gfx12 [CK] Fix example_grouped_conv_bwd_data_xdl_fp16 with ksplit = 2 (#2943) root cause: AK1 and BK1 may different in class template. so we need calculate k0 per block separately when ksplit is not 1. [CK][Examples] Extending support for rdna3/4 in following examples: (#2884) * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski * [CK][Examples] Extending support for rdna3/4 in following examples: -example_gemm_xdl_splitk_reduce_multi_d_fp16 -example_gemm_xdl_splitk_reduce_multi_d_bf16 -example_gemm_xdl_splitk_reduce_bf16A_i8B -example_gemm_xdl_splitk_reduce_bfp16 -example_splitk_gemm_bias_e_permute_xdl_fp32 -example_gemm_add_multiply_xdl_fp16 -example_complex_contraction_bilinear_xdl_fp32 -example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 -example_batched_gemm_bias_e_permute_xdl_fp16 -example_gemm_xdl_fp16 -example_gemm_xdl_fp16_av2 -example_gemm_xdl_wavelet_fp16 -example_gemm_add_add_fastgelu_xdl_bf16 -example_gemm_add_add_fastgelu_xdl_fp16 -example_gemm_add_add_fastgelu_xdl_fp32 -example_grouped_gemm_xdl_fp32 -example_grouped_gemm_xdl_fp16 -example_grouped_gemm_xdl_bf16 -example_cgemm_xdl_bf16 -example_cgemm_xdl_fp16 Signed-off-by: Michal Kulikowski --------- Signed-off-by: Michal Kulikowski hot fix check eid range (#2924) * hot fix check eid range * fix clang format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng Weight Preshuffle Block Scale gemm support (#2877) * initial commit * remove extra files * fixing errors * updated ReadMe file for mapping of diff quants with diff configs * addressing review comments * addressing review comments * Resolved merge conflicts * [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled The get_preshuffle_or was not working as expected, which led to incorrect behavior in the quantization preshuffle process. This change replaces it with the more reliable is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied. * initial commit * debugging * working fp8 for init constant * fp8 working with all inits * updated block level code with comments * changing the loop iter * debugging * debugging * debugging * code fix * code clean up * clang formatted * Add comment * code cleanup * clang formatted * merge conflicts fixes * applying the latest int4 changes to the piepline * fixing test code for updated traits * Adding gtest * review comments addressed * addressing review comments * remove c++20 code * added flush cache changes --------- Co-authored-by: Cong Ma Co-authored-by: root increase time limit for AITER tests (#2948) Code style clean-up and documentation The following changes were made: - Clean-up of variable namings - Addition of README - Removal of num_cu and occupancy args; such options are meant for testing purposes and should not be exposed to the user - Removal of CK_TILE_PIPELINE_MEMORY macro and PipelineTypeTraits class since we only support one pipeline at the moment. Fix timing issue in CK_TILE GEMM example (#2940) * feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults * feat(grouped_gemm_multi_d): add functionality to run persistant kernel * fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint Fix timing issue in CK_TILE GEMM example (#2940) * style: clang format * refactor: removed unused file [CK] Add command option instance_index and param_mask to run partial ck test (#2889) * [CK] Add command option instance_index and param_mask to run partial ck test Many CK test are instance test. it will loop all instance in the instance library. It causes test often out-of-time if we run test on simulator/emulator. This PR add option instance_index and param_mask to reduce the workload of instance test instance_index: only run test 1 available instance with specified index. param_mask: filter the embedded parameter with specified mask * fix CI error * fix clang format --------- Co-authored-by: illsilin_amdeng [CK_TILE]enhance elementwise test (#2683) * enhance elementwise * fix ci issues --- CHANGELOG.md | 1 + .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 2 + .../run_grouped_gemm_multi_d_example.inc | 10 +- .../test_grouped_gemm_multi_d.cpp | 15 +- .../test_grouped_gemm_multi_d_util.hpp | 135 +++++++++++++++++- 5 files changed, 149 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 438320d907..be613fb78c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index d5203a799c..0789452ada 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -76,6 +76,7 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 8; static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; @@ -116,6 +117,7 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index 8f275b069b..e1647c037b 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, << std::endl; for(int i = 0; i < group_count; i++) { - Ms.push_back(256 /* + 256 * i */); - Ns.push_back(256 /* + 512 * i */); - Ks.push_back(64 /* + 384 * i */); + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); - ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp index deea2fc852..c6356a6b2c 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -31,7 +31,8 @@ template + PipelineType Pipeline_val_, + bool Persistent_val_> struct KernelConfig { using ALayoutType = ALayout_; @@ -56,15 +57,19 @@ struct KernelConfig static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; static constexpr auto Scheduler_ = Scheduler_val_; static constexpr PipelineType Pipeline_ = Pipeline_val_; + static constexpr bool Persistent_ = Persistent_val_; static constexpr int BlockPerCu_ = 1; }; // clang-format off using KernelTypes = ::testing::Types< - // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3 - KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4 + // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4 >; // clang-format on diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 4c13b4a7f7..30a61a081b 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -93,7 +93,6 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -229,6 +228,100 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) + { + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + } + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -379,9 +472,43 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(gemm_descs)); - invoke_grouped_gemm(gemm_descs, - ck_tile::stream_config{nullptr, false, 1}, - gemm_workspace.GetDeviceBuffer()); + if constexpr(Config::Persistent_) + { + std::vector> kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = gemm_descs[0].k_batch > 1; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back( + ck_tile::UniversalGemmKernelArgs<1, 1, DsDataType::size()>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + } + else + { + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + } // Copy results back to host for validation for(int i = 0; i < group_count; i++) From a7da3c68b979bd46c315da09208271d26f5e2900 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:38:07 -0700 Subject: [PATCH 2/6] Add a new gemm pipeline based on ComputeV4 which utilizes async copy API (#2949) * check in pipeline and policy for async load in mi350, need to make sure TileAccessPattern is warp_raked or block_raked solve merge conflicts * fix cmakelists * make it build * fix? buffer async fence * relax fences; it appears it only is needed between pairs of ping-pongs * remove fences * remove fences * cleanup and reformat * add steps annotations * comment all pipeline steps / remove unexplainable syncs * clang-format * add comment * cleanup kernel types for test * fix comment * fix hardcoded warp size * faithfully copy block gemm from compute v4 policy to async policy * make async test gfx950 only * fix cmake logic * set separate compile options for async * refine comment in policy * try update hotloop scheduler * cleanup comments * test more K block sizes * unhardcode Ks, sort of * add large odd test case * fix build for quant * add comment to hot loop scheduler and rename enum * reformat * reword the pipeline description * reformat * address review / add static asserts / typo fix * update changelog --- CHANGELOG.md | 1 + include/ck_tile/core/arch/arch.hpp | 16 + include/ck_tile/ops/gemm.hpp | 2 + .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 531 ++++++++++++++++++ ...ine_ag_bg_cr_comp_async_default_policy.hpp | 101 ++++ ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 108 ++-- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 3 - .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 3 - test/ck_tile/gemm/CMakeLists.txt | 6 + .../gemm/test_gemm_pipeline_comp_async.cpp | 17 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 11 +- .../gemm/test_gemm_pipeline_ut_cases.inc | 51 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 15 +- 13 files changed, 803 insertions(+), 62 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index be613fb78c..9aadc3dc54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 28ded5439a..3b12cf061b 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -275,4 +275,20 @@ CK_TILE_DEVICE static constexpr auto get_device_arch() return gfx12_t{}; #endif } + +enum LLVMSchedGroupMask : int32_t +{ + NONE = 0, + ALU = 1 << 0, + VALU = 1 << 1, + SALU = 1 << 2, + MFMA = 1 << 3, + VMEM = 1 << 4, + VMEM_READ = 1 << 5, + VMEM_WRITE = 1 << 6, + DS = 1 << 7, + DS_READ = 1 << 8, + DS_WRITE = 1 << 9, + ALL = (DS_WRITE << 1) - 1, +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 6e07dbc00e..5edde31cd9 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -40,6 +40,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp new file mode 100644 index 0000000000..2c8d008127 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -0,0 +1,531 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompAsync +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::Three; + } + else + { + return TailNumber::Two; + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Three) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Two) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. +#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error( + "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported"); +#endif + } +}; + +/** + * @brief Compute optimized pipeline version async; which is based on V4. + * + * This pipeline introduces asynchronous load from global memory to LDS, + * skipping the intermediate loading into pipeline registers. + */ +template +struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync +{ + using Base = BaseGemmPipelineAgBgCrCompAsync; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + + constexpr index_t WaveSize = get_warp_size(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { + // TODO: this will likely need to be redesigned after (1) changes to reading from + // LDS and (2) re-profiling + ignore = i; + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1 + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1 + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6 + }); + __builtin_amdgcn_sched_barrier(0); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + // TODO support multi-ABD + static_assert(1 == std::tuple_size_v); + static_assert(1 == std::tuple_size_v); + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + // TODO currently fused elementwise are not supported + ignore = a_element_func; + ignore = b_element_func; + static_assert(std::is_same_v, + element_wise::PassThrough>); + static_assert(std::is_same_v, + element_wise::PassThrough>); + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + // TODO currently only support A matrix row major, B matrix col major; if A matrix is + // col major or B is row major, need to combine with transpose load api + static_assert(!(is_a_col_major || is_b_row_major), + "only support A matrix is row major, B matrix is col major!"); + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// global window & register ///////////////// + // A DRAM tile window(s) for load + auto a_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); + // B DRAM window(s) for load + auto b_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); + + // this pipeline has a pair of LDS buffers per logical tile + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); + auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + + // LDS tile windows for storing, one per LDS buffer + auto a_copy_lds_window0 = make_tile_window( + a_lds_block0, make_tuple(number{}, number{}), {0, 0}); + + auto a_copy_lds_window1 = make_tile_window( + a_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window0 = make_tile_window( + b_lds_block0, make_tuple(number{}, number{}), {0, 0}); + + auto b_copy_lds_window1 = make_tile_window( + b_lds_block1, make_tuple(number{}, number{}), {0, 0}); + + // initialize DRAM window steps, used to advance the DRAM windows + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + // read A(0), B(0) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + // initialize block gemm + auto block_gemm = BlockGemm(); + + // initialize C block tile + auto c_block_tile = block_gemm.MakeCBlockTile(); + clear_tile(c_block_tile); + + // read A(1), B(1) from DRAM to LDS window(1) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + // register tiles; double buffering -> a register tile corresponds to a LDS tile window + ALdsTile a_block_tile0; + ALdsTile a_block_tile1; + + BLdsTile b_block_tile0; + BLdsTile b_block_tile1; + + // LDS tile windows for reading; + // they share the data pointer with the LDS windows for storing + // but also associate with a distribution to produce a register tile when reading + auto a_lds_ld_window0 = + make_tile_window(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto a_lds_ld_window1 = + make_tile_window(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); + auto b_lds_ld_window0 = + make_tile_window(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + auto b_lds_ld_window1 = + make_tile_window(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + + static_assert(!(is_tile_window_linear_v) && + !(is_tile_window_linear_v) && + !(is_tile_window_linear_v) && + !(is_tile_window_linear_v), + "LDS windows must not be linear"); + + // write to LDS window(0) must complete before the local prefetch + block_sync_lds_direct_load(); + // read A(0), B(0) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // LDS window(0) contents are overwritten below by global prefetch, need to sync + block_sync_lds(); + // read A(2), B(2) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync( + a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step); + Base::GlobalPrefetchAsync( + b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step); + + if(HasHotLoop) + { + // we have had 3 global prefetches so far, indexed (0, 1, 2). + index_t i_global_read = amd_wave_read_first_lane(3); + // alternate ping: (read to register tile(1), use register tile(0) as gemm input) + // pong: (read to register tile(0), use register tile(1) as gemm input) + do + { + // ping + { + // read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // LDS window(1) contents are overwritten by global prefetch, need to sync + block_sync_lds(); + // read A(i), B(i) from DRAM to LDS window(1) + // and advance the DRAM windows + Base::GlobalPrefetchAsync(a_copy_lds_window1, + a_tile_windows[number<0>{}], + a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_copy_lds_window1, + b_tile_windows[number<0>{}], + b_dram_tile_window_step); + // C(i-3) = A(i-3) @ B(i-3) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + HotLoopScheduler(); + } + // pong + { + // write to LDS window(0) must complete before the local prefetch + block_sync_lds_direct_load(); + // read A(i), B(i) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // LDS window(0) contents are overwritten by global prefetch, need to sync + block_sync_lds(); + // read A(i+1), B(i+1) from DRAM to LDS window(0) + // and advance the DRAM windows + Base::GlobalPrefetchAsync(a_copy_lds_window0, + a_tile_windows[number<0>{}], + a_dram_tile_window_step); + Base::GlobalPrefetchAsync(b_copy_lds_window0, + b_tile_windows[number<0>{}], + b_dram_tile_window_step); + // C(i-2) = A(i-2) @ B(i-2) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + HotLoopScheduler(); + } + i_global_read += 2; + } while(i_global_read < num_loop); + } + + // 3 block gemms remaining + if constexpr(TailNum == TailNumber::Three) + { + { + // read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + { + // read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0) + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); + Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); + // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + } + { + // C(num_loop) = A(num_loop) @ B(num_loop) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + } + else + // 2 block gemms remaining + { + { + // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1) + Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); + Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); + // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + } + { + // C(num_loop) = A(num_loop) @ B(num_loop) + block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + } + } + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0, + p_smem_1); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0, + p_smem_1); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp new file mode 100644 index 0000000000..23104375d6 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAgBgCrCompAsync +// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor +// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy +struct GemmPipelineAgBgCrCompAsyncDefaultPolicy + : public UniversalGemmBasePolicy +{ + static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; + static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + return transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackB(); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + return transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c8f874acd6..4030783ecc 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -9,6 +9,26 @@ namespace ck_tile { +template +struct has_a_tile_access_pattern : std::false_type +{ +}; + +template +struct has_a_tile_access_pattern> : std::true_type +{ +}; + +template +struct has_b_tile_access_pattern : std::false_type +{ +}; + +template +struct has_b_tile_access_pattern> : std::true_type +{ +}; + template struct UniversalGemmBasePolicy { @@ -30,8 +50,25 @@ struct UniversalGemmBasePolicy static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; - static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; - static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + // Default tile access patterns + static constexpr auto DefaultATileAccessPattern = tile_distribution_pattern::thread_raked; + static constexpr auto DefaultBTileAccessPattern = tile_distribution_pattern::thread_raked; + + static constexpr auto getATileAccessPattern() + { + if constexpr(has_a_tile_access_pattern::value) + return Derived::ATileAccessPattern; + else + return DefaultATileAccessPattern; + } + + static constexpr auto getBTileAccessPattern() + { + if constexpr(has_b_tile_access_pattern::value) + return Derived::BTileAccessPattern; + else + return DefaultBTileAccessPattern; + } template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -168,11 +205,12 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -500,23 +538,25 @@ struct UniversalGemmBasePolicy // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: KPerBlock X MPerBlock else { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -536,23 +576,25 @@ struct UniversalGemmBasePolicy // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: NPerBlock X KPerBlock else { - using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -573,7 +615,7 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, - ATileAccessPattern, + getATileAccessPattern(), NumWaveGroups>; return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } @@ -594,7 +636,7 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, - BTileAccessPattern, + getBTileAccessPattern(), NumWaveGroups>; return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 926f63b5a9..9e40e1f08c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -15,9 +15,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using Base::I1; using Base::I2; - using Base::ATileAccessPattern; - using Base::BTileAccessPattern; - template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index eea8038edf..f9278bf985 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -15,9 +15,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using Base::I1; using Base::I2; - using Base::ATileAccessPattern; - using Base::BTileAccessPattern; - template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() { diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 44e2433060..1ca7f4fc7d 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -11,6 +11,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -mllvm -enable-noalias-to-md-conversion=0 ) +set(EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) @@ -60,6 +61,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() + if(GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_gemm_pipeline_comp_async test_gemm_pipeline_comp_async.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_comp_async PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_ASYNC_OPTIONS}) + endif() + if(GPU_TARGETS MATCHES "gfx11|gfx12") # On Radeon devices, build the WMMA version instead add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp b/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp new file mode 100644 index 0000000000..c41d40937d --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_comp_async.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompAsync + : public TestCkTileGemmPipeline> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompAsync + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompAsync, KernelTypesCompAsync); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index a55cd100c1..243a823653 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -26,9 +26,10 @@ using Intrawave = ck_tile::integral_constant; -using Mem = ck_tile::integral_constant; -using CompV3 = ck_tile::integral_constant; -using CompV4 = ck_tile::integral_constant; +using Mem = ck_tile::integral_constant; +using CompV3 = ck_tile::integral_constant; +using CompV4 = ck_tile::integral_constant; +using CompAsync = ck_tile::integral_constant; using Persistent = std::true_type; using NonPersistent = std::false_type; @@ -129,6 +130,10 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompAsync = ::testing::Types< + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync> +>; + using KernelTypesCompV4Wmma = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I64, I64, I32, I16, I16, I16, Intrawave, CompV4>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index c824d034a9..f793f81cc9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -10,18 +10,25 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; - constexpr int K = 320; + std::vector Ks; + for (auto K_count: {2, 3, 4, 10, 11}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } for(int M : Ms) { - if constexpr(std::is_same_v) + for(int K : Ks) { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); - } - else - { - this->Run(M, N, K); + if constexpr(std::is_same_v) + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } + else + { + this->Run(M, N, K); + } } } } @@ -30,7 +37,12 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; - constexpr int K = 320; + + std::vector Ks; + for (auto K_count: {2, 3, 4, 10, 11}) + { + Ks.push_back(K_count * TestFixture::K_Tile); + } constexpr int VecLoadSize = (std::is_same_v || std::is_same_v || std::is_same_v) @@ -39,22 +51,25 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) for(int M : Ms) { - if constexpr(std::is_same_v) + for (int K: Ks) { - if(M % VecLoadSize == 0) + if constexpr(std::is_same_v) { - this->Run(M, N, K); + if(M % VecLoadSize == 0) + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + this->Run(M, N, K); } } - else - { - this->Run(M, N, K); - } } } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index af4f8d3d38..01bc3d7522 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -37,7 +37,8 @@ enum struct GemmPipelineType { Mem, CompV3, - CompV4 + CompV4, + CompAsync }; template @@ -70,6 +71,15 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync; + using pipeline = ck_tile::GemmPipelineAgBgCrCompAsync; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; } +}; + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -110,7 +120,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4) ? true : false; + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || + PipelineType == GemmPipelineType::CompAsync); // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; From a4ab33f539ac9d7209c6274958dc0285eacf3e78 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Thu, 2 Oct 2025 20:09:49 +0600 Subject: [PATCH 3/6] Fix building test_fmha_bwd_fp32 on SLES15 (#2962) --- test/ck_tile/fmha/test_fmha_bwd_fp32.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp index d409d0dd30..09010d4b22 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp32.cpp @@ -15,6 +15,6 @@ const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tupl const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +const std::string init_method = "uf"; #include "test_fmha_bwd.inc" From 6fc28ab4934d3668bf4ec96db1e082cf26b11384 Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Thu, 2 Oct 2025 12:13:51 -0600 Subject: [PATCH 4/6] [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle (#2897) * [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle When TransposeC and QuantPreshuffle are both true, Aquant generates correct result. * [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle - Add unit tests * Fix bug in is_quantpreshuffle_enabled * clang format --------- Co-authored-by: ThomasNing --- .../block_universal_gemm_as_aquant_bs_cr.hpp | 41 ++++++++++++--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 2 +- .../pipeline/tile_gemm_quant_traits.hpp | 3 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 4 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 1 + .../test_gemm_quant_fixtures.hpp | 52 +++++++++++++++++-- .../test_gemm_quant_typed.cpp | 21 +++++++- 7 files changed, 109 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index d4bece1a83..cb20bdbd50 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -346,13 +346,40 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { if constexpr(Traits::TransposeC) // transposed C { - static_assert(false, - "It is not supported yet to enable both Preshuffle " - "and TransposeC."); - // TODO: - // A new tile distribution is needed for the Preshuffle and - // Transpose combination. For instance, with mnk at 16x16x32, lanes - // 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ. + constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; + auto pull_from_lane = (__lane_id() & (Traits::WarpGemm::kN - 1)) * + Traits::AQPerBlock + + kQScale; + + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * + scale_reg_f); + }); } else { diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index a0b6fc5821..bba2bc8400 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -73,7 +73,7 @@ struct is_quantpreshuffle_enabled }; template -struct is_quantpreshuffle_enabled +struct is_quantpreshuffle_enabled> { static constexpr bool value = T::PreshuffleQuant; }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index 52a326a897..c4429b76f9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -39,6 +39,7 @@ template struct TileGemmQuantTraits @@ -62,7 +63,7 @@ struct TileGemmQuantTraits using AsLayout = ALayout_; using BsLayout = BLayout_; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; static constexpr bool UsePersistentKernel = UsePersistentKernel_; diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 93a13ba5af..3a49e69c37 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -7,7 +7,9 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") # Typed Test Suite for GEMM Quantization - add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp) + add_gtest_executable(test_tile_gemm_quant_typed + test_gemm_quant_typed.cpp + ) target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 355e9fce32..80167a1d21 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -87,6 +87,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test QuantType, ALayout, BLayout, + GemmConfig::TransposeC, DoubleSmemBuffer>; // Let the derived class create the appropriate pipeline and epilogue diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 98f88f4d53..21eabd6041 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -41,6 +41,22 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 32; }; +struct GemmConfigPreshuffleQuant : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; +}; + +struct GemmConfigTransposeC : public GemmConfigBase +{ + static constexpr bool TransposeC = true; +}; + +struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase +{ + static constexpr bool PreshuffleQuant = true; + static constexpr bool TransposeC = true; +}; + struct GemmConfigPreshuffleB { static constexpr bool kPadM = false; @@ -100,6 +116,24 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase + auto shuffle_aq(const ck_tile::HostTensor* t, int block_aq_k) + { + if(t->get_lengths().size() != 2) + { + throw std::runtime_error("Host tensor is not rank 2 tensor."); + } + int m_ = t->get_lengths()[0]; + int aqk_ = t->get_lengths()[1]; + if(aqk_ % block_aq_k != 0) + { + throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k."); + } + ck_tile::HostTensor t_view({m_, aqk_ / block_aq_k, block_aq_k}); + std::copy(t->begin(), t->end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {1, 0, 2}); + } + // AQuant-specific data generation void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) { @@ -150,7 +184,17 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase aq_shuffle_host = + shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } b_k_n_dev_buf.ToDevice(b_k_n.data()); // Create args for kernel execution @@ -245,7 +289,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase; // Type combinations for each quantization type // clang-format off using AQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false std::tuple, std::tuple, std::tuple, - std::tuple + std::tuple, + + // PreshuffleQuant = false && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = false + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on From cadafde722f838d1fc0b08130cd4fca168acc29c Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 2 Oct 2025 11:15:24 -0700 Subject: [PATCH 5/6] add the check of granularity for atomic add (#2959) --- .../impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 4 ++++ test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 7e9020d796..02639dbf3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -682,6 +682,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK{}] <= 1 && (arg.KBatch > 1)) + { + return false; + } else { if constexpr(NXdlPerWave32 > 0) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index f793f81cc9..66ef05b0ba 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -11,7 +11,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; std::vector Ks; - for (auto K_count: {2, 3, 4, 10, 11}) + for(auto K_count : {2, 3, 4, 10, 11}) { Ks.push_back(K_count * TestFixture::K_Tile); } @@ -36,10 +36,10 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) TYPED_TEST(TEST_SUITE_NAME, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 1024; + constexpr int N = 1024; std::vector Ks; - for (auto K_count: {2, 3, 4, 10, 11}) + for(auto K_count : {2, 3, 4, 10, 11}) { Ks.push_back(K_count * TestFixture::K_Tile); } @@ -51,7 +51,7 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) for(int M : Ms) { - for (int K: Ks) + for(int K : Ks) { if constexpr(std::is_same_v) From 0a30c3063068dcefea2291309fbe269812d06956 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Thu, 2 Oct 2025 11:54:45 -0700 Subject: [PATCH 6/6] fix build on legacy systems without cpp20 compiler (#2958) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index 8f24c9bfe1..3f15e8c6aa 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -31,7 +31,7 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) if(ck_tile::get_device_name() != "gfx950") { - gemmParams.emplace_back(256, 256, 128, 2); + gemmParams.push_back({256, 256, 128, 2}); } for(auto& params : gemmParams)