From ae4444dfba108f54f4218aba2a16fc11c1dcb833 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 11 Nov 2025 07:42:26 -0800 Subject: [PATCH] formatting (#3182) [ROCm/composable_kernel commit: 06c651b100c9dc50753277069bdc68411da7ca1a] --- include/ck_tile/ops/gemm_quant.hpp | 1 + .../block/block_gemm_quant_common.hpp | 38 +++++++++++++++++++ ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 18 ++------- .../block_universal_gemm_as_aquant_bs_cr.hpp | 17 ++------- .../block_universal_gemm_as_bs_bquant_cr.hpp | 17 ++------- 5 files changed, 48 insertions(+), 43 deletions(-) create mode 100644 include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3273131875..3e16d937cb 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp new file mode 100644 index 0000000000..d695888b88 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Common utilities for quantized GEMM block operations +template +struct BlockGemmQuantCommon +{ + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index df55081b69..2d92745f75 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -100,21 +101,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); - - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 8b95ec6ddf..1f72f4dc12 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -543,20 +544,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9db444b57f..660c30aa6e 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" namespace ck_tile { @@ -376,20 +377,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - - return c_block_tensor; + return BlockGemmQuantCommon:: + MakeCBlockTile(); } template