mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fixup! Add DstDataType as a template parameter to load_tile_with_elementwise, and use it for type conversion
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -194,13 +195,26 @@ struct tile_window_with_static_distribution
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DstDataType_>(tile_dstr);
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_windows,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
tile_windows,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
if constexpr(std::is_same_v<DstDataType_, typename Base::DataType>)
|
||||
{
|
||||
return dst_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto ret = make_static_distributed_tensor<DstDataType_>(tile_dstr);
|
||||
sweep_tile(ret, [&](auto i) {
|
||||
element_wise::PassThrough pass_through{};
|
||||
pass_through(ret(i), dst_tensor(i));
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
|
||||
@@ -115,6 +115,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using ATypeToUse = typename DetermineWarpPrecType<ADataType, BDataType>::prec_type;
|
||||
using BTypeToUse = typename DetermineWarpPrecType<BDataType, ADataType>::prec_type;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
@@ -437,7 +440,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A/B tiles in LDS
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block, b_lds_block] = Base::template GetABLdsTensorViews<ATypeToUse, BTypeToUse>(p_smem);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
@@ -477,7 +480,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{});
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
@@ -486,7 +489,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{});
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
@@ -518,10 +521,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
// global read 1
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{});
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{});
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -562,11 +565,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
}
|
||||
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func, ATypeToUse{});
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func, BTypeToUse{});
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
Reference in New Issue
Block a user