mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK-Tile] Fix quant example code (#2813)
* initial commit * remove extra files * fixing errors * updated ReadMe file for mapping of diff quants with diff configs * addressing review comments * addressing review comments * Resolved merge conflicts * [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled The get_preshuffle_or was not working as expected, which led to incorrect behavior in the quantization preshuffle process. This change replaces it with the more reliable is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied. --------- Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
@@ -70,17 +70,21 @@ struct get_bq_data_type_or<T, Default>
|
||||
using type = typename T::BQDataType;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
struct get_preshuffle_or
|
||||
{
|
||||
using type = Default;
|
||||
template <typename T>
|
||||
concept HasStaticPreshuffleQuant = requires {
|
||||
{ T::PreshuffleQuant } -> std::convertible_to<decltype(T::PreshuffleQuant)>;
|
||||
};
|
||||
|
||||
template <typename T, typename Default>
|
||||
requires requires { typename T::PreshuffleQuant; }
|
||||
struct get_preshuffle_or<T, Default>
|
||||
template <typename T>
|
||||
struct is_quantpreshuffle_enabled
|
||||
{
|
||||
using type = typename T::PreshuffleQuant;
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <HasStaticPreshuffleQuant T>
|
||||
struct is_quantpreshuffle_enabled<T>
|
||||
{
|
||||
static constexpr auto value = T::PreshuffleQuant;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
@@ -198,9 +202,9 @@ struct QuantGemmKernel
|
||||
using BQLayout = remove_cvref_t<
|
||||
typename detail::get_bq_layout_or<GemmPipeline, typename GemmPipeline::BLayout>::type>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant = remove_cvref_t<
|
||||
typename detail::get_preshuffle_or<GemmPipeline, std::false_type>::type>::value;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool PreshuffleQuant =
|
||||
detail::is_quantpreshuffle_enabled<GemmPipeline_>::value;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
|
||||
Reference in New Issue
Block a user