diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index d1c06d4378..e940ea2afd 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -48,19 +48,26 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template -CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple& tile_windows, - ElementWise_ elementwise, - number = {}, - bool_constant = {}) +CK_TILE_DEVICE auto +load_tile_with_elementwise(const ck_tile::tuple& tile_windows, + ElementWise_ elementwise, + DstDataType_ = {}, + number = {}, + bool_constant = {}) { // TODO: Tile windows should work with unknown number of params // Load element_wise API works only when the input type is a tuple-type - return tile_windows[number<0>{}].load( - tile_windows, elementwise, number{}, bool_constant{}); + return tile_windows[number<0>{}].load(tile_windows, + elementwise, + DstDataType_{}, + number{}, + bool_constant{}); } // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index da90675fdd..009d60c1e9 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -184,15 +184,17 @@ struct tile_window_with_static_distribution */ template CK_TILE_DEVICE auto load(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, + DstDataType_ = {}, number = {}, bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, tile_windows, elementwise, @@ -208,7 +210,7 @@ struct tile_window_with_static_distribution bool oob_conditional_check = true> CK_TILE_DEVICE void load(DistributedTensor& dst_tensor, const ck_tile::tuple& tile_windows, - ElementWise_ elementwise, + ElementWise_ elementwise = {}, number = {}, bool_constant = {}) const {