[CK_TILE] Add type traits to detect tile window types at compile time (#2158)

* added WindowType enum to tile_window_structs and static assert checks in computev4 pipeline

* added type traits instead of enum to tile_window() and tile_window_linear() with debug comments

* removed comments, added documentation and clang format
This commit is contained in:
Aviral Goel
2025-05-07 02:00:39 -05:00
committed by GitHub
parent 8a0d659f92
commit 769336b640
3 changed files with 130 additions and 0 deletions

View File

@@ -1164,4 +1164,82 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
/**
* @brief Type trait to determine if a type is a tile window with static distribution.
*
* Defaults to `false_type`. Specializations define when the trait evaluates to `true`.
*
* @tparam T The type to check.
*/
template <typename T>
struct is_tile_window_with_static_distribution : std::false_type
{
};
/**
* @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`.
*
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
* @tparam WindowLengths_ Static window lengths.
* @tparam StaticTileDistribution_ Tile distribution policy.
* @tparam NumCoord Number of coordinate dimensions.
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
struct is_tile_window_with_static_distribution<
tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>> : std::true_type
{
};
/**
* @brief Helper variable template to check if a type is a tile window with static distribution.
*
* Equivalent to `is_tile_window_with_static_distribution<T>::value`.
*
* @tparam T The type to check.
*/
template <typename T>
inline constexpr bool is_tile_window_with_static_distribution_v =
is_tile_window_with_static_distribution<T>::value;
/**
* @brief Type trait to determine if a type is a tile window with static lengths.
*
* Defaults to `false_type`. Specializations define when the trait evaluates to `true`.
*
* @tparam T The type to check.
*/
template <typename T>
struct is_tile_window_with_static_lengths : std::false_type
{
};
/**
* @brief Specialization for `tile_window_with_static_lengths` to evaluate to `true_type`.
*
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
* @tparam WindowLengths_ Static window lengths.
*/
template <typename BottomTensorView_, typename WindowLengths_>
struct is_tile_window_with_static_lengths<
tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
{
};
/**
* @brief Helper variable template to check if a type is a tile window with static lengths.
*
* Equivalent to `is_tile_window_with_static_lengths<T>::value`.
*
* @tparam T The type to check.
*/
template <typename T>
inline constexpr bool is_tile_window_with_static_lengths_v =
is_tile_window_with_static_lengths<T>::value;
} // namespace ck_tile

View File

@@ -44,6 +44,7 @@ template <typename BottomTensorView_,
typename LinearBottomDims_>
struct tile_window_linear
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
@@ -1215,4 +1216,49 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
/**
* @brief Type trait to determine if a type is a linear tile window.
*
* Defaults to `false_type`. Specialized to `true_type` for types that match
* `tile_window_linear<...>`.
*
* @tparam T The type to check.
*/
template <typename T>
struct is_tile_window_linear : std::false_type
{
};
/**
* @brief Specialization of `is_tile_window_linear` for `tile_window_linear`.
*
* Evaluates to `true_type` if the type is a `tile_window_linear` with the given template
* parameters.
*
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
* @tparam WindowLengths_ Static window lengths.
* @tparam StaticTileDistribution_ Tile distribution policy.
* @tparam LinearBottomDims_ Dimensions of the bottom tensor view that participate in linearization.
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
struct is_tile_window_linear<tile_window_linear<BottomTensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>> : std::true_type
{
};
/**
* @brief Helper variable template to check if a type is a linear tile window.
*
* Equivalent to `is_tile_window_linear<T>::value`.
*
* @tparam T The type to check.
*/
template <typename T>
inline constexpr bool is_tile_window_linear_v = is_tile_window_linear<T>::value;
} // namespace ck_tile

View File

@@ -337,6 +337,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{0, 0},
BLdsTileDistr);
static_assert(
!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>)&&!(
is_tile_window_linear_v<
decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
"LDS windows must not be linear");
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);