From 9bbf7016b6fad5ab07452f43a898c5342fa18c62 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 13 Oct 2025 12:17:18 +0000 Subject: [PATCH] Merge commit '634634f5c09a3b42f5f838a5af9c948602e246db' into develop --- example/ck_tile/03_gemm/gemm_utils.hpp | 40 +- include/ck_tile/host.hpp | 1 + include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 2 + include/ck_tile/ops/batched_transpose.hpp | 2 + include/ck_tile/ops/common.hpp | 3 +- .../ops/common/load_interleaved_pk_type.hpp | 2 +- include/ck_tile/ops/elementwise.hpp | 2 + include/ck_tile/ops/epilogue.hpp | 2 + include/ck_tile/ops/flatmm.hpp | 2 + include/ck_tile/ops/fmha.hpp | 2 + include/ck_tile/ops/fused_moe.hpp | 2 + include/ck_tile/ops/gemm.hpp | 10 +- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 37 +- .../gemm_pipeline_ag_bg_cr_comp_v6.hpp | 770 ++++++++++++++++++ ...peline_ag_bg_cr_comp_v6_default_policy.hpp | 56 ++ include/ck_tile/ops/gemm_quant.hpp | 8 +- include/ck_tile/ops/grouped_convolution.hpp | 2 + include/ck_tile/ops/image_to_column.hpp | 2 + include/ck_tile/ops/layernorm2d.hpp | 2 + include/ck_tile/ops/norm_reduce.hpp | 2 + include/ck_tile/ops/permute.hpp | 2 + include/ck_tile/ops/reduce.hpp | 2 + include/ck_tile/ops/rmsnorm2d.hpp | 2 + include/ck_tile/ops/smoothquant.hpp | 2 + include/ck_tile/ops/softmax.hpp | 2 + include/ck_tile/ops/topk.hpp | 2 + include/ck_tile/ops/topk_softmax.hpp | 2 + test/ck_tile/gemm/CMakeLists.txt | 10 +- .../gemm/test_gemm_pipeline_compv6.cpp | 17 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 23 + test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 18 +- tile_engine/ops/gemm/gemm_common.hpp | 52 -- tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- 33 files changed, 1011 insertions(+), 74 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 07b925d0eb..a831a4f26c 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,8 +16,9 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 +#define CK_TILE_PIPELINE_COMPUTE_V6 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7 template constexpr ck_tile::index_t get_k_warp_tile() @@ -251,9 +252,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigComputeV6 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::index_t NumWaveGroups = 1; }; template @@ -484,6 +505,15 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 86110d57ec..d815b1db40 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -16,6 +16,7 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_dropout_randval.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 1768c802d5..6c0972e10a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -9,5 +9,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index ca0088c812..5822d7b91b 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -12,5 +12,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 7c6adc3ec2..eff2d625b3 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" -#include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index f8432b9da0..fb7a05044f 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core/config.hpp" -#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 4858245ec4..7f2303932e 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 6cc0fa8540..ec5a8ef445 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 1714789e63..41463e6a2d 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -14,5 +14,7 @@ #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 31de21a726..6b25c089bd 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -60,5 +60,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index ddb64a2189..71721f3408 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -16,5 +16,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 5edde31cd9..2a4f9d21e3 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,20 +30,22 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" @@ -72,5 +74,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 6f9d53467f..806a471397 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -161,8 +161,43 @@ struct BatchedGemmKernel } CK_TILE_HOST static auto - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool { + if(kargs.batch_count < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !"); + } + return false; + } + if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_A must be non-negative and at least K * M!"); + } + return false; + } + if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_B must be non-negative and at least K * N!"); + } + return false; + } + if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_E must be non-negative and at least M * N!"); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp new file mode 100644 index 0000000000..2ae9001098 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV6 +{ + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. +#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are " + "supported in this pipeline context."); +#endif + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 +template +struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 +{ + using Base = BaseGemmPipelineAgBgCrCompV6; + using BasePImpl = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using BlockGemm = remove_cvref_t())>; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA(); + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV6", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK), + concat('x', TailNum), + concat('_', KRepeat), + concat('_', DoubleSmemBuffer), + concat('_', Preshuffle), + concat('_', HasHotLoop)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public BasePImpl + { + }; + + template <> + struct PipelineImpl : public BasePImpl + { + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2); + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + + constexpr index_t A_LDS_Read_Width = KPerXDL; + constexpr index_t B_LDS_Read_Width = KPerXDL; + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num - + num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num); + constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // stage 2 + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 3 + static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + // TODO: Add Multi A/B support + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// LDS desc, window & register ///////////////// + using ALdsType = + remove_cvref_t; + using BLdsType = + remove_cvref_t; + auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem); + ALdsType& a_lds_block = ABLdsTensorViews.at(I0); + BLdsType& b_lds_block = ABLdsTensorViews.at(I1); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + using acopy_dram_type = + remove_cvref_t; + using bcopy_dram_type = + remove_cvref_t; + + using a_copy_lds_window_type = + remove_cvref_t; + using b_copy_lds_window_type = + remove_cvref_t; + + using a_lds_load_tile_distr_type = + remove_cvref_t; + using b_lds_load_tile_distr_type = + remove_cvref_t; + + auto&& aWindows = + BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& bWindows = + BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + acopy_dram_type& a_copy_dram_window = aWindows.at(I0); + a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1); + a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + bcopy_dram_type& b_copy_dram_window = bWindows.at(I0); + b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1); + b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile[Base::GlobalBufferNum]; + BBlockTile b_block_tile[Base::GlobalBufferNum]; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_lds_tile; + BLdsTile b_lds_tile; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Global prefetch 1 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // Local prefill 1 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]); + } + + // Global prefetch 2 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // Global prefetch 3 + a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Local prefetch 1 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + if(HasHotLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 2 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution< + Problem>()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, + a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution< + Problem>()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, + b_block_tile[vmem_buf_idx]); + } + + // Global prefetch 4 + a_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + } + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 2 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += Base::HotloopUnroll; + } while(i < (num_loop - Base::PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 3 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]); + } + + block_sync_lds(); + } + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&]() { + static_for<0, KRepeat - 1, 1>{}([&]() { + __syncthreads(); + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 4 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + __syncthreads(); + }); + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + + return c_block_tile; + } + }; + + public: + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp new file mode 100644 index 0000000000..6ac702d38b --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV6, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distiribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. +struct GemmPipelineAgBgCrCompV6DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index cde0b6833f..61cb96c8f4 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,8 +3,8 @@ #pragma once -#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" @@ -15,11 +15,13 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 09b50f26b0..1dd13b6246 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,5 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 93664ea138..2307b05190 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -7,5 +7,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index afbb817db1..9ce22137bf 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 7dc3e8b7e7..aa074b7f9f 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -7,5 +7,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 1cc3d9cbc3..46512c57fe 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index a6721c9305..d628e9c945 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 610541b2e4..00afcf4aed 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index dc164dc1a0..1aa14c69e1 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index b23e869d81..d559dc15e2 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 1dc563f757..040c6b8ddc 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d0a810de4f..d9657a9764 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 1ca7f4fc7d..24cc1bc5ab 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -24,12 +24,13 @@ endif() if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) + + target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -55,10 +56,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv6 test_gemm_pipeline_compv6.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv6 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx95") diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp new file mode 100644 index 0000000000..a72ff98055 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV6 + : public TestCkTileGemmPipeline> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV6 + +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV6, KernelTypesCompV6); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index bba106174c..aa1f610022 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -29,6 +29,7 @@ using Interwave = ck_tile::integral_constant; using CompV3 = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; +using CompV6 = ck_tile::integral_constant; using CompAsync = ck_tile::integral_constant; using Persistent = std::true_type; @@ -130,6 +131,28 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompV6 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6> +>; using KernelTypesCompAsync = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 01bc3d7522..994510c060 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -38,6 +38,7 @@ enum struct GemmPipelineType Mem, CompV3, CompV4, + CompV6, CompAsync }; @@ -71,6 +72,15 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; + using pipeline = ck_tile::GemmPipelineAgBgCrCompV6; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV6"; } +}; + template struct GemmPipelineTypeSelector { @@ -120,11 +130,13 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::CompAsync); + constexpr bool TransposeC = false; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; // TODO: For now - but this should also be a test parameter - constexpr bool TransposeC = false; constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; @@ -140,8 +152,6 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - static constexpr bool StructuredSparsity = false; - static constexpr bool NumWaveGroup = 1; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits>{}; } -// Permutation function for pk_int4_t -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 6, i) = i4x2; - } - } - } -} - // Structure to hold kernel traits for dispatcher struct KernelTraits { diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index bbf0c92e67..1298c78d18 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -96,7 +96,7 @@ class GemmProfiler // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; // permute_tensor_b(b_k_n_dev); - permute_vectors_i4x4_b(b_k_n_dev); + ck_tile::permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else