From 82b5464c67a7fabbbbdbdfe01ba5faddb1b9839d Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Mon, 5 Jan 2026 13:55:41 +0000 Subject: [PATCH] fixup! Add DstDataType as a template parameter to load_tile_with_elementwise, and use it for type conversion --- include/ck_tile/core/tensor/tile_window.hpp | 26 ++++++++++++++----- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 17 +++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 009d60c1e9..a11f30040a 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -17,6 +17,7 @@ #include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { @@ -194,13 +195,26 @@ struct tile_window_with_static_distribution bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, - tile_windows, - elementwise, - number{}, - bool_constant{}); - return dst_tensor; + tile_windows, + elementwise, + number{}, + bool_constant{}); + + if constexpr(std::is_same_v) + { + return dst_tensor; + } + else + { + auto ret = make_static_distributed_tensor(tile_dstr); + sweep_tile(ret, [&](auto i) { + element_wise::PassThrough pass_through{}; + pass_through(ret(i), dst_tensor(i)); + }); + return ret; + } } template using ADataType = remove_cvref_t>; using BDataType = remove_cvref_t>; + using ATypeToUse = typename DetermineWarpPrecType::prec_type; + using BTypeToUse = typename DetermineWarpPrecType::prec_type; + using BlockGemm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; @@ -437,7 +440,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // Definitions of all needed tiles // A/B tiles in LDS - auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block, b_lds_block] = Base::template GetABLdsTensorViews(p_smem); // Tile distribution for load from lds constexpr auto a_lds_load_tile_distr = @@ -477,7 +480,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // Load tile — during value loading, an elementwise function is executed for each A0, // A1, … AN. The values A0, A1, … AN are read by the same thread. auto elementwise_As_res = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); + load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{}); // Move each A — the enhanced function move_tile_window is executed, which takes a tuple // as input. @@ -486,7 +489,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // Load tile — during value loading, an elementwise function is executed for each B0, // B1, … BN. The values B0, B1, … BN are read by the same thread. auto elementwise_Bs_res = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); + load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{}); // Move each B — the enhanced function move_tile_window is executed, which takes a tuple // as input. @@ -518,10 +521,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // global read 1 - elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{}); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); - elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{}); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); @@ -562,11 +565,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } elementwise_As_res = - load_tile_with_elementwise(a_copy_dram_window, a_element_func); + load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{}); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); elementwise_Bs_res = - load_tile_with_elementwise(b_copy_dram_window, b_element_func); + load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{}); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);