mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Add DstDataType as a template parameter to load_tile_with_elementwise, and use it for type conversion
This commit is contained in:
@@ -48,19 +48,26 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename... TileWindow_,
|
||||
template <typename FirstTileWindow_,
|
||||
typename... RestTileWindow_,
|
||||
typename ElementWise_,
|
||||
typename DstDataType_ = typename FirstTileWindow_::Base::DataType,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_with_elementwise(const ck_tile::tuple<FirstTileWindow_, RestTileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
DstDataType_ = {},
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should work with unknown number of params
|
||||
// Load element_wise API works only when the input type is a tuple-type
|
||||
return tile_windows[number<0>{}].load(
|
||||
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
return tile_windows[number<0>{}].load(tile_windows,
|
||||
elementwise,
|
||||
DstDataType_{},
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
|
||||
@@ -184,15 +184,17 @@ struct tile_window_with_static_distribution
|
||||
*/
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
typename DstDataType_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
DstDataType_ = {},
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
auto dst_tensor = make_static_distributed_tensor<DstDataType_>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_windows,
|
||||
elementwise,
|
||||
@@ -208,7 +210,7 @@ struct tile_window_with_static_distribution
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
|
||||
const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
ElementWise_ elementwise = {},
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user