mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API (#3153)
* Make fmha_fwd_v3() compatible with fmha_fwd() * Decouple get_fwd_blobs() and FmhaFwdKernel * Decouple compatibility checks from get_fwd_blobs() * Extract product feature checks out of get_fwd_blobs() * Remove duplicated code in factories and redundant checks * Remove FmhaFwdKernel<>::GetName() * Let FmhaFwdApiPool support pipelines with different mask_impl * Add tile setting for fmha fwd v3 pipeline * Add fwd v3 instances to tile_example_fmha_fwd manually * Remove unused function import * Undo irrelevant changes * Remove fwd v3 instances from tile_example_fmha_fwd * Finish fmha fwd v3 kernel instance codegen * Fix formatting * Remove unused F_idx attribute * Add is_generic_attention_mask<> traits * Add constraints to the fmha fwd v3 pipeline * Unify traits & problem used for fmha fwd v3 * Unify kernel launch code for fmha fwd v2 & v3 * Unify kernel template selection logic * Use same kernel codegen template for both v2 & v3 * Rename api() property to render() method * Allow specifying filter for fmha fwd api pool * Allow specifying function name when rendering api pool items * Separate fmha fwd v3 kernel dispatching logic from v2 * Remove lambda assignment * Add simple v2/v3 dispatch logic * Stop generating empty if-clauses: skip iterating over dictionaries that have no traits, and avoid assigning i_* to them. 
* Use "".join() to concatenate fmha fwd api string content * Add more feature checks for fmha fwd v3 pipeline * Check features before dispatch to fmha_fwd_v3() * Add more feature checks for fmha_fwd_v3() * Add missing filter call * Use Tuple to preserve the dtype order * Fix wrong pipeline matching logic * Add fmha fwd v3 group mode instances * Add functor_transform<> * Add type constraints to make_tile_window() * Remove fmha fwd v3 example * Fix wrong product (aiter mha_fwd()) config * Fix wrong fmha fwd v2/v3 selection logic * Fix formatting * Add comment to warn v3 kernel users * Fix wrong codegen logic * Remove unnecessary param * Fix format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1552,6 +1552,81 @@ CK_TILE_HOST_DEVICE static void print(const indexing<UpLength, IndexingAdaptor>&
|
||||
printf("}");
|
||||
}
|
||||
|
||||
template <typename Functor, typename LowLength>
|
||||
struct functor_transform : public base_transform<1, 1>
|
||||
{
|
||||
using LowerIndex = multi_index<1>;
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths = decltype(make_tuple(LowLength{}));
|
||||
|
||||
Functor functor_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr functor_transform() = default;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr functor_transform(const Functor& functor,
|
||||
const LowLength& low_length)
|
||||
: functor_{functor}, up_lengths_{make_tuple(low_length)}
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = functor_(idx_up[number<0>{}]);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff&,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& up_idx) const
|
||||
{
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
const auto idx_low_old = idx_low;
|
||||
calculate_lower_index(idx_low, up_idx);
|
||||
idx_diff_low = idx_low - idx_low_old;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_always_mapped_to_valid_lower_index()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE static constexpr bool
|
||||
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
// Note: When using functor_transform, ensure that the transformed coordinates
|
||||
// are always valid for vectorized load/store operations.
|
||||
template <typename LowVectorLengths, typename LowVectorStrides>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
|
||||
const LowVectorStrides& low_vector_strides)
|
||||
{
|
||||
return make_tuple(low_vector_lengths, low_vector_strides);
|
||||
}
|
||||
};
|
||||
|
||||
//*******************************************************************************************************
|
||||
|
||||
template <typename LowLength>
|
||||
@@ -1671,6 +1746,13 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
|
||||
return offset<LowLength, OffsetLength>{low_length, offset_length};
|
||||
}
|
||||
|
||||
template <typename Functor, typename LowLength>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_functor_transform(const Functor& functor,
|
||||
const LowLength& low_length)
|
||||
{
|
||||
return functor_transform<Functor, LowLength>{functor, low_length};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
|
||||
@@ -1263,7 +1263,9 @@ struct tile_window_with_static_lengths
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename = std::enable_if_t<is_tensor_view_v<TensorView_>>>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
@@ -1310,7 +1312,10 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
|
||||
tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename = std::enable_if_t<is_tile_distribution_v<StaticTileDistribution>>>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
|
||||
Reference in New Issue
Block a user