[CK_TILE] Add 2:4 structured sparsity support for fp16 gemm (#1957)

* add structured sparsity fp16 support for gemm

* added reviewer suggestions

* update changelog

* update changelog

* add reviewers suggestions

* Minor fix

* clang fix

* fix doxygen
This commit is contained in:
jakpiase
2025-04-11 12:18:26 +02:00
committed by GitHub
parent 5f885d2b7a
commit 6c61f4d237
13 changed files with 401 additions and 20 deletions

View File

@@ -364,6 +364,49 @@ struct FillConstant
}
};
//----------------------------------------------------------------------------------------------
/// @brief Zeroes selected elements of a range so that it fits a 2:4 structured
///        sparsity pattern, i.e. every mask-aligned subgroup of 4 elements
///        contains at most 2 non-zero elements.
///
/// Elements are walked in order and paired with consecutive entries of the
/// flattened @ref masks table (wrapping around at its end). An element whose
/// mask entry is 0 is overwritten with a value-initialized `T`; an element whose
/// mask entry is 1 is left untouched.
///
/// Masked elements are assigned zero rather than multiplied by the mask, so
/// that Inf/NaN inputs cannot survive in a masked position (`inf * 0` and
/// `nan * 0` are both NaN), and an explicit loop is used because the mask index
/// must advance strictly in element order.
///
/// @note The "at most 2 of 4" guarantee holds for subgroups aligned with the
///       rows of the mask table. If `start` is not a multiple of 4, a subgroup
///       of 4 consecutive elements can straddle two mask rows and end up with
///       more than 2 non-zero elements.
/// @note Assumes a value-initialized `T` (`T{}`) represents zero. This holds for
///       built-in arithmetic types; confirm for custom numeric types.
///
/// @tparam T Element type used for the zero written into masked positions.
template <typename T>
struct AdjustToStructuredSparsity
{
    /// Offset into the flattened mask table that is applied to the first element.
    size_t start{0};
    // Each row of 4 entries is one keep(1)/drop(0) pattern with at most two
    // kept elements: first the six patterns keeping exactly two elements, then
    // the four patterns keeping exactly one.
    // clang-format off
    static constexpr int32_t masks[] = {0, 0, 1, 1,
                                        0, 1, 0, 1,
                                        0, 1, 1, 0,
                                        1, 0, 0, 1,
                                        1, 0, 1, 0,
                                        1, 1, 0, 0,
                                        0, 0, 0, 1,
                                        0, 0, 1, 0,
                                        0, 1, 0, 0,
                                        1, 0, 0, 0};
    // clang-format on

    /// @brief Applies the sparsity masks in place to the range [first, last).
    /// @param first Iterator to the first element to process.
    /// @param last  Iterator one past the last element to process.
    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        constexpr auto mask_count = std::size(masks);

        auto index = start;
        for(; first != last; ++first, ++index)
        {
            // Select instead of multiply: guarantees an exact zero even when
            // the original element is Inf or NaN.
            if(masks[index % mask_count] == 0)
            {
                *first = T{};
            }
        }
    }

    /// @brief Applies the sparsity masks in place to a whole range.
    ///
    /// Participates in overload resolution only when the iterator overload is
    /// callable with `std::begin(range)` / `std::end(range)`.
    /// @param range Range whose elements are masked in place.
    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const AdjustToStructuredSparsity&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
template <typename T, bool UseCos = true, bool UseAbs = false>
struct FillTrigValue
{