From db12c41b569bf55181e46bd6fb5c9bd9e15c1a56 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Tue, 11 Nov 2025 17:12:49 +0000 Subject: [PATCH] Merge commit '1b1c46e508c1fd40a03f54114b6b78629032fb4f' into develop --- .github/workflows/therock-ci-linux.yml | 6 +- .github/workflows/therock-test-component.yml | 4 +- .github/workflows/therock-test-packages.yml | 2 +- example/01_gemm/gemm_wmma_fp8_v3.cpp | 10 +- .../38_block_scale_gemm/CMakeLists.txt | 2 +- .../38_block_scale_gemm/gemm_quant_basic.cpp | 4 + .../38_block_scale_gemm/gemm_utils.hpp | 8 + .../blockwise_gemm_pipeline_wmma_selector.hpp | 3 + .../blockwise_gemm_pipeline_wmmaops_base.hpp | 47 +-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 302 ++++++++++-------- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 250 +++++++++------ .../gridwise_ab_transfer_thread_tiles.hpp | 98 +++++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 6 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 159 ++++++--- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 25 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 14 + include/ck_tile/host/tensor_shuffle_utils.hpp | 98 ++++-- .../gemm/warp/warp_gemm_attribute_wmma.hpp | 1 + include/ck_tile/ops/gemm_quant.hpp | 1 + .../block/block_gemm_quant_common.hpp | 38 +++ ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 22 +- .../block_universal_gemm_as_aquant_bs_cr.hpp | 28 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 23 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 5 +- .../pipeline/tile_gemm_quant_traits.hpp | 5 +- ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 4 +- ...emm_wmma_universal_f16_f8_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_km_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp | 3 +- ...emm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp | 3 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 2 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 14 +- .../test_gemm_quant_fixtures.hpp | 24 +- tile_engine/ops/commons/test_benchmark.sh | 3 + tile_engine/ops/commons/test_validation.py | 3 + tile_engine/ops/commons/validation_utils.py | 2 +- tile_engine/ops/gemm/codegen_utils.py | 2 +- tile_engine/ops/gemm/gemm_benchmark.hpp | 2 +- tile_engine/ops/gemm/gemm_benchmark.py | 2 +- .../ops/gemm/gemm_benchmark_single.cpp | 2 +- tile_engine/ops/gemm/gemm_common.hpp | 2 +- tile_engine/ops/gemm/gemm_instance_builder.py | 3 + tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- .../gemm_multi_d/gemm_multi_d_benchmark.hpp | 2 +- .../gemm_multi_d/gemm_multi_d_benchmark.py | 2 +- .../gemm_multi_d_benchmark_single.cpp | 2 +- .../ops/gemm_multi_d/gemm_multi_d_common.hpp | 2 +- .../gemm_multi_d_instance_builder.py | 3 + .../gemm_multi_d/gemm_multi_d_profiler.hpp | 2 +- .../commons/validation_utils.py | 2 +- .../gemm_preshuffle_benchmark.hpp | 3 + .../gemm_preshuffle_benchmark.py | 2 +- .../gemm_preshuffle_benchmark_single.cpp | 2 +- .../gemm_preshuffle_common.hpp | 2 +- .../gemm_preshuffle_instance_builder.py | 4 +- .../gemm_preshuffle_profiler.hpp | 2 +- 65 files changed, 845 insertions(+), 455 deletions(-) create mode 100644 include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index f4d0c0063c..86d134e456 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -53,8 +53,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit path: "TheRock" + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Setup ccache run: | @@ -77,6 +77,8 @@ jobs: - name: Patch rocm-libraries run: | git config --global --add safe.directory '*' + # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo + rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps @@ -128,7 +130,7 @@ jobs: run: | python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --artifact-group ${{ env.AMDGPU_FAMILIES }} \ --build-dir TheRock/build \ --upload diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 1ccc1d57bc..27eff4fdb0 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,13 +51,13 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' with: ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + ARTIFACT_GROUP: ${{ inputs.amdgpu_families }} OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} VENV_DIR: ${{ env.VENV_DIR }} FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index efb5a6b1a0..81632fce48 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit + ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit - name: "Configuring CI options" env: diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp index 0376820b7b..2f8eac113b 100644 --- a/example/01_gemm/gemm_wmma_fp8_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t; using ComputeTypeA = ck::f8_t; using ComputeTypeB = ck::f8_t; -using ALayout = Row; +using ALayout = Col; using BLayout = Col; using CLayout = Row; @@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 64, - 8, 8, + 16, 16, // AK1, BK1 16, 16, 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 4, 16, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, - S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, + 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB>; diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 7358d4d749..b1ae9369a2 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp) target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index b22596537f..d605a2b780 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f) int main(int argc, char* argv[]) { +#if CK_TILE_USE_WMMA + return !run_gemm_example(argc, argv); +#else // Use non-preshuffled GemmConfig for 2D block scale support return !run_gemm_example(argc, argv); +#endif } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 589caf88f4..1839c7f98d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +template +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template constexpr auto BlockGemmPipeline_Selector() { @@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) @@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 265db9166a..abc9720714 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -30,6 +30,7 @@ template struct BlockwiseGemmWmmaops_pipeline_base { @@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; using ThisThreadBlock = ThisThreadBlock; @@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr index_t B_KRow = 1; #endif - static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); - static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner; + static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread); + static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread); static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - static constexpr index_t KRepeat = KPerBlock / KPack; static constexpr auto WmmaK = Number{}; @@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |MRepeat|MWave |KRow |MLane |KPack - return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); } __device__ static auto CalculateBThreadOriginDataIndex() @@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack - return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); } template @@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base return make_tuple(c_thread_m, c_thread_n); } - using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + using Tuple7 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. @@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base * repeat dimensions. */ __host__ __device__ - BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), - Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(), + Tuple7 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && @@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); // C[M, N, NumRegWmma] @@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 5d7c570428..5f731933e2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; - using Base::WaveSize; using typename Base::HotLoopInstList; using Base::A_K1; @@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0), - a_thread_buf); - + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); if constexpr(m0 == I0) { if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); } else @@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), + make_tuple(I0, n0, k0, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs( I0)[Number{}], b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), + make_tuple(I0, n0, I0, I0, I0, I0, I0), b_thread_buf); }); } } - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, I0, I0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + I0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, I0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + static_for<0, KInner, 1>{}([&](auto) { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); }); }); }); @@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, I1, I1, Number{})); + make_tuple(Number{}, I1, I1, I1, I1, I1, Number{})); // B[NRepeat, N1, N2, KPack] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; @@ -399,6 +418,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; @@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - m0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0_inner, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs(I0)[Number< n0 * BScaleStruct::num_scale_k_block + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && + m0 == MRepeat - 1 && n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. - // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && - n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } }); }); }); @@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); using AThreadCopy = @@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 83dadb2175..cbe13b6e00 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; @@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same_v) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } else { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.b_scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } }); @@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 465952e285..23f16d38e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -17,6 +17,9 @@ template {}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(ABK1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). + // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } } __device__ static constexpr auto GetBlockStep() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 68476ef3bf..a36ccd43ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -313,14 +313,16 @@ struct ABTransferWaveTiles // This is a block descriptor used to read LDS memory into register // It's defined in a way consistent with the existing implementation to // avoid changes in the pipelines - return make_naive_tensor_descriptor(make_tuple(Number{}, + return make_naive_tensor_descriptor(make_tuple(I1, Number{}, + Number{}, Number{}, Number{}, Number{}, Number{}), - make_tuple(Number{}, + make_tuple(I0, Number{}, + Number{}, Number{}, Number{}, Number{}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index fa7eb4faaa..38ebdab65e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - // TODO: I am pretty sure this is always 16 and *should* always be 16. - static constexpr auto KPack = - math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16); + static constexpr index_t KPerWmmaBlk = + WmmaSelector::selected_wmma + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return b1_block_copy_step; } + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(K1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). + // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } + } + template __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) { - constexpr auto a_wave_desc = [&]() { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto A_KRow = I2; -#else - constexpr auto A_KRow = I1; -#endif - return transform_tensor_descriptor( - ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return a_wave_desc; + return MakeWmmaTileDescriptor(ABlockDesc_{}); } template __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) { - constexpr auto b0_wave_desc = [&]() { - // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto B_KRow = I2; -#else - constexpr auto B_KRow = I1; -#endif - return transform_tensor_descriptor( - B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return b0_wave_desc; + return MakeWmmaTileDescriptor(B0BlockDesc_{}); } template @@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 MRepeat, LRepeat, KPack, + KInner, true>())>; // TransposeC (must be true to work), C' = B' x A' // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7a5e324468..56f09cee96 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), + static constexpr index_t KPerWmmaBlk = WmmaSelector::selected_wmma - .k_per_wmma); + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, MPerWmma, AK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingA, PermuteA, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, NPerWmma, BK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingB, PermuteB, BBlockTransferThreadClusterLengths_BK0_N_BK1, @@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base NPerWmma, MRepeat, NRepeat, - KPack>())>; + KPack, + KInner>())>; // Used to create obj in global function and pass it to Run method using EpilogueCShuffle = diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index bca68764f9..55ede990af 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -95,6 +95,7 @@ struct wmma_type __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index f29f3eeed6..e3b5c96d91 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -24,16 +24,43 @@ template auto shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } } template @@ -55,21 +82,46 @@ template auto shuffle_b_permuteN(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; - - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, - NRepeat, - k_ / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, + GemmConfig::N_Warp, + GemmConfig::N_Warp_Tile, + NRepeat, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + } } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 90f6204ff3..dd2931f6b7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma static constexpr index_t kM = Impl::kM; static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; + static constexpr index_t kCMLane = Impl::kCMLane; static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3273131875..3e16d937cb 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp new file mode 100644 index 0000000000..d695888b88 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Common utilities for quantized GEMM block operations +template +struct BlockGemmQuantCommon +{ + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index df55081b69..6422c07e1d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -81,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -100,21 +101,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b95ec6ddf..bbdd3128bf 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -24,13 +25,11 @@ struct BlockGemmAQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -348,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3. - constexpr uint32_t kTileRowsOfCPerThread = 4; + constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8; decltype(threadIdx.x) pull_from_lane = 0; if constexpr(WarpGemm::kM == 16) { @@ -409,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // desired row coefficient auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; + ; constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; // Multiply by 4 because output is stored in tiles of 4 @@ -543,20 +543,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9db444b57f..28ae709bf0 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -24,13 +25,11 @@ struct BlockGemmBQuantBase float scale_reg_f = 0.f; if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { - scale_reg_f = - ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast(scale)); + scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast(scale), 0); } else if constexpr(std::is_same_v) { @@ -376,20 +375,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 36cbb87877..15d2727f3b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -240,7 +240,10 @@ struct QuantGemmKernel return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static auto BlockSize() + { + return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize); + } CK_TILE_HOST static constexpr QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs& hostArgs) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index c4429b76f9..3a5b86382d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -41,7 +41,8 @@ template + bool UsePersistentKernel_ = false, + int VectorSize_ = 16> struct TileGemmQuantTraits { static constexpr bool kPadM = kPadM_; @@ -50,7 +51,7 @@ struct TileGemmQuantTraits static constexpr QuantType kQuantType = QuantType_; - static constexpr int _VectorSize = 16; + static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; using ALayout = ALayout_; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 71b5c5e7cf..806b6e684d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -48,7 +48,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index f4489dc45f..4516d06492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 423f86365c..5ace0594f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -53,7 +53,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index 2eb28958e6..27deab1c8c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -56,7 +56,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index d10b9facd5..bd5c7d8783 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -48,7 +48,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index d9d16ede65..1956d1a951 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 9277e5e901..934c6aa7ef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp index e97a649c19..9860b81b78 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp index c8f1b85ddb..4d7169565a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp index fc0220a502..3728368bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp index b87cf64b0f..3506575f5d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp index 31ad66409e..eef0d6de6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp @@ -50,7 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp index 4c37c398fe..2418be62b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -55,7 +55,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp index 6b5314b701..38f2869303 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 3a49e69c37..1c4a25c8bd 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -5,7 +5,7 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 6454101daf..6226a2de9e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -69,7 +69,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; - + // WP pipeline requires per-thread tile size aligned to Problem::VectorLoadSize. + // static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + // VectorLoadSize == 0). gfx9 cards match the requirements but it fails on gfx12. so we only + // need to check the limitation on RDNA cards, i.e. assume wave size is 32. + constexpr ck_tile::index_t WaveSize = 32; + constexpr ck_tile::index_t MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile); + constexpr bool SupportVectorSize16 = + (M_Warp_Tile * K_Warp_Tile * sizeof(ADataType) * MIterPerWarp / WaveSize) % 16 == 0; + constexpr int VectorSize = PreshuffleB ? (SupportVectorSize16 ? 16 : 8) : 16; using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, @@ -89,7 +97,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test ALayout, BLayout, GemmConfig::TransposeC, - DoubleSmemBuffer>; + DoubleSmemBuffer, + false, + VectorSize>; // Let the derived class create the appropriate pipeline and epilogue static_cast(this) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index cabc0ec02c..5aac095514 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -7,6 +7,16 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else + return is_8bit ? 64 : 32; +#endif +} + struct GemmConfigBase { static constexpr bool kPadM = false; @@ -40,7 +50,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleQuant : public GemmConfigBase @@ -75,7 +85,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefill : public GemmConfigBase @@ -94,7 +104,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill @@ -132,7 +142,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase #include diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 1aff42b902..8885c821c1 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 575e5240a8..3c6bbc34d3 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp index 53dcdb5e1f..f8c196e32a 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py index fb81b9c2c2..044e08baca 100755 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp index 032a625354..41d2f736e1 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp index 4732f2a1ba..899221547f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py index 3f7858f146..cc167fb75f 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp index 8e19c11c7d..3a2cdc71fe 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py index b38ff5dffb..70ce3b0d72 100644 --- a/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py +++ b/tile_engine/ops/gemm_preshuffle/commons/validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp index 77a9f26527..748fe581d3 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include "ck_tile/core.hpp" diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py index 0217a439f2..d8892be7d6 100755 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. import sys import json diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 1f03d1cf9b..4fbb25f0c9 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index abaa5ebd46..1b2cfe3735 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py index 57c250f57e..9ce6d8cb25 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_instance_builder.py @@ -1,5 +1,5 @@ -## Copyright © Advanced Micro Devices, Inc. or its affiliates. -## SPDX-License-Identifier: MIT +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT import argparse import os diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 85b731c231..739bd7e677 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc. or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once