From cfa11f2d1fad3cf19b0fbae601349bce4525024d Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 27 Nov 2025 08:35:18 +0000 Subject: [PATCH] Rename InterleavedPKTypeLoader to ConverterLoader, and load_int4_tile to load_and_convert_tile --- .../ops/common/load_and_convert_tile.hpp | 6 ++--- .../block/block_universal_gemm_as_bs_cr.hpp | 24 +++++++++---------- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 3 ++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 4 ++-- .../block_universal_gemm_as_aquant_bs_cr.hpp | 4 ++-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 4 ++-- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 6 +++-- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 3 ++- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 3 ++- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 3 ++- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 8 +++---- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 8 +++---- 12 files changed, 41 insertions(+), 35 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index 0da1f11229..f2ee23e98b 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -9,7 +9,7 @@ namespace ck_tile { template -struct InterleavedPKTypeLoader +struct ConverterLoader { template CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) @@ -34,12 +34,12 @@ template -CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) { if constexpr(std::is_same_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + ConverterLoader::load_interleaved_pk_type(dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index a0fa732d1a..040051a5e8 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -228,10 +228,10 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile( + a_warp_tile_, a_block_window); + load_and_convert_tile( + b_warp_tile_, b_block_window); // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -294,10 +294,10 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile( + a_warp_tile_, a_block_window); + load_and_convert_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -425,10 +425,10 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile(a_warp_tile_, - a_lds_gemm_window); - load_int4_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile( + a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile( + b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index dcc11015e7..74632ee5b0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -74,7 +74,8 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_int4_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, + dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d13..132b31ed62 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -261,10 +261,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b52..b40168b2af 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,9 +248,9 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 313e449c7b..ece393b40d 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -258,10 +258,10 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cd70c2ca86..b6fd25139e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -200,7 +200,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); } template @@ -210,7 +211,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, + b_dram_window); } template using DestDataType = typename ABlockTile_::DataType; using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 22dd78e070..0a4793bb12 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -171,7 +171,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); } template (b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, + b_dram_window); } template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 0f3951ffcc..49064bdb76 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -349,7 +349,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -430,7 +430,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -455,7 +455,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -503,7 +503,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index e4de7e4211..5455944de0 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -335,7 +335,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -421,7 +421,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -458,7 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -516,7 +516,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); });