diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3bb728df23..716b1f4ecb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1164,4 +1164,82 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a tile window with static distribution. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_with_static_distribution : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. + * @tparam NumCoord Number of coordinate dimensions. + */ +template +struct is_tile_window_with_static_distribution< + tile_window_with_static_distribution> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static distribution. + * + * Equivalent to `is_tile_window_with_static_distribution::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_distribution_v = + is_tile_window_with_static_distribution::value; + +/** + * @brief Type trait to determine if a type is a tile window with static lengths. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_with_static_lengths : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_lengths` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + */ +template +struct is_tile_window_with_static_lengths< + tile_window_with_static_lengths> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static lengths. + * + * Equivalent to `is_tile_window_with_static_lengths::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_lengths_v = + is_tile_window_with_static_lengths::value; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 1e24e660f6..5ecaf5ca17 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -44,6 +44,7 @@ template struct tile_window_linear { + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -1215,4 +1216,49 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a linear tile window. + * + * Defaults to `false_type`. Specialized to `true_type` for types that match + * `tile_window_linear<...>`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_linear : std::false_type +{ +}; + +/** + * @brief Specialization of `is_tile_window_linear` for `tile_window_linear`. + * + * Evaluates to `true_type` if the type is a `tile_window_linear` with the given template + * parameters. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. + * @tparam LinearBottomDims_ Dimensions of the bottom tensor view that participate in linearization. + */ +template +struct is_tile_window_linear> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a linear tile window. + * + * Equivalent to `is_tile_window_linear::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_linear_v = is_tile_window_linear::value; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 667bb80ce9..6535f612f1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -337,6 +337,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 {0, 0}, BLdsTileDistr); + static_assert( + !(is_tile_window_linear_v)&&!(is_tile_window_linear_v)&&!( + is_tile_window_linear_v< + decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v), + "LDS windows must not be linear"); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);