mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK_TILE] Add type traits to detect tile window types at compile time (#2158)
* added WindowType enum to tile_window_structs and static assert checks in computev4 pipeline * added type traits instead of enum to tile_window() and tile_window_linear() with debug comments * removed comments, added documentation and clang format
This commit is contained in:
@@ -1164,4 +1164,82 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Type trait to determine if a type is a tile window with static distribution.
|
||||
*
|
||||
* Defaults to `false_type`. Specializations define when the trait evaluates to `true`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_tile_window_with_static_distribution : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`.
|
||||
*
|
||||
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
|
||||
* @tparam WindowLengths_ Static window lengths.
|
||||
* @tparam StaticTileDistribution_ Tile distribution policy.
|
||||
* @tparam NumCoord Number of coordinate dimensions.
|
||||
*/
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
struct is_tile_window_with_static_distribution<
|
||||
tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper variable template to check if a type is a tile window with static distribution.
|
||||
*
|
||||
* Equivalent to `is_tile_window_with_static_distribution<T>::value`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
inline constexpr bool is_tile_window_with_static_distribution_v =
|
||||
is_tile_window_with_static_distribution<T>::value;
|
||||
|
||||
/**
|
||||
* @brief Type trait to determine if a type is a tile window with static lengths.
|
||||
*
|
||||
* Defaults to `false_type`. Specializations define when the trait evaluates to `true`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_tile_window_with_static_lengths : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for `tile_window_with_static_lengths` to evaluate to `true_type`.
|
||||
*
|
||||
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
|
||||
* @tparam WindowLengths_ Static window lengths.
|
||||
*/
|
||||
template <typename BottomTensorView_, typename WindowLengths_>
|
||||
struct is_tile_window_with_static_lengths<
|
||||
tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper variable template to check if a type is a tile window with static lengths.
|
||||
*
|
||||
* Equivalent to `is_tile_window_with_static_lengths<T>::value`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
inline constexpr bool is_tile_window_with_static_lengths_v =
|
||||
is_tile_window_with_static_lengths<T>::value;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -44,6 +44,7 @@ template <typename BottomTensorView_,
|
||||
typename LinearBottomDims_>
|
||||
struct tile_window_linear
|
||||
{
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
@@ -1215,4 +1216,49 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Type trait to determine if a type is a linear tile window.
|
||||
*
|
||||
* Defaults to `false_type`. Specialized to `true_type` for types that match
|
||||
* `tile_window_linear<...>`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_tile_window_linear : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization of `is_tile_window_linear` for `tile_window_linear`.
|
||||
*
|
||||
* Evaluates to `true_type` if the type is a `tile_window_linear` with the given template
|
||||
* parameters.
|
||||
*
|
||||
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
|
||||
* @tparam WindowLengths_ Static window lengths.
|
||||
* @tparam StaticTileDistribution_ Tile distribution policy.
|
||||
* @tparam LinearBottomDims_ Dimensions of the bottom tensor view that participate in linearization.
|
||||
*/
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename LinearBottomDims_>
|
||||
struct is_tile_window_linear<tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper variable template to check if a type is a linear tile window.
|
||||
*
|
||||
* Equivalent to `is_tile_window_linear<T>::value`.
|
||||
*
|
||||
* @tparam T The type to check.
|
||||
*/
|
||||
template <typename T>
|
||||
inline constexpr bool is_tile_window_linear_v = is_tile_window_linear<T>::value;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user