Add DstDataType as a template parameter to load_tile_with_elementwise, and use it for type conversion

This commit is contained in:
Sami Aario
2025-12-15 13:41:17 +00:00
parent 2e798d15e1
commit d75d38bf05
2 changed files with 18 additions and 9 deletions

View File

@@ -48,19 +48,26 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename... TileWindow_,
template <typename FirstTileWindow_,
typename... RestTileWindow_,
typename ElementWise_,
typename DstDataType_ = typename FirstTileWindow_::Base::DataType,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
CK_TILE_DEVICE auto
load_tile_with_elementwise(const ck_tile::tuple<FirstTileWindow_, RestTileWindow_...>& tile_windows,
ElementWise_ elementwise,
DstDataType_ = {},
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should work with unknown number of params
// Load element_wise API works only when the input type is a tuple-type
return tile_windows[number<0>{}].load(
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
return tile_windows[number<0>{}].load(tile_windows,
elementwise,
DstDataType_{},
number<i_access>{},
bool_constant<oob_conditional_check>{});
}
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.

View File

@@ -184,15 +184,17 @@ struct tile_window_with_static_distribution
*/
template <typename... TileWindow_,
typename ElementWise_,
typename DstDataType_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
DstDataType_ = {},
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
auto dst_tensor = make_static_distributed_tensor<DstDataType_>(tile_dstr);
load(dst_tensor,
tile_windows,
elementwise,
@@ -208,7 +210,7 @@ struct tile_window_with_static_distribution
bool oob_conditional_check = true>
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
ElementWise_ elementwise = {},
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{