[CK-Tile] Fix quant example code (#2813)

* initial commit

* remove extra files

* fixing errors

* updated ReadMe file for mapping of diff quants with diff configs

* addressing review comments

* addressing review comments

* Resolved merge conflicts

* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled

The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.

---------

Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Khushbu Agarwal
2025-09-10 17:15:39 -07:00
committed by GitHub
parent b4207c01c7
commit 80a61afb9b
10 changed files with 343 additions and 1288 deletions

View File

@@ -70,17 +70,21 @@ struct get_bq_data_type_or<T, Default>
using type = typename T::BQDataType;
};
template <typename T, typename Default>
struct get_preshuffle_or
{
using type = Default;
template <typename T>
concept HasStaticPreshuffleQuant = requires {
{ T::PreshuffleQuant } -> std::convertible_to<decltype(T::PreshuffleQuant)>;
};
template <typename T, typename Default>
requires requires { typename T::PreshuffleQuant; }
struct get_preshuffle_or<T, Default>
template <typename T>
struct is_quantpreshuffle_enabled
{
using type = typename T::PreshuffleQuant;
static constexpr bool value = false;
};
template <HasStaticPreshuffleQuant T>
struct is_quantpreshuffle_enabled<T>
{
static constexpr auto value = T::PreshuffleQuant;
};
} // namespace detail
@@ -198,9 +202,9 @@ struct QuantGemmKernel
using BQLayout = remove_cvref_t<
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant = remove_cvref_t<
typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool PreshuffleQuant =
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;