[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>

[ROCm/composable_kernel commit: 6a6177a246]
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 1b1dd65b83
commit a5824466fb
28 changed files with 642 additions and 175 deletions

View File

@@ -0,0 +1,54 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <type_traits>
namespace ck_tile {
namespace detail {
// Helper method to automatically determine compute type
// Selects the largest type of the two. If both of them are packed data types, defaults to fp8.
template <typename ADataType, typename BDataType>
struct auto_compute_type
{
using LargestInputType = largest_type_t<ADataType, BDataType>;
// Sanity check: there are no packed types larger than 1 byte yet, but if we add them
// this logic should change
static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));
using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
};
// Helper method to determine compute type, defaulting an explicitly passed-in compute type
template <typename ComputeDataType, typename ADataType, typename BDataType>
struct mixed_prec_compute_type
{
using type = std::conditional_t<std::is_void_v<ComputeDataType>,
typename auto_compute_type<ADataType, BDataType>::type,
ComputeDataType>;
};
} // namespace detail
template <typename ComputeDataType, typename ADataType, typename BDataType>
using mixed_prec_compute_type_t =
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
// Helper method to determine compute type, defaulting to input data type
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
// ComputeDataType is used.
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
using mixed_prec_compute_type_from_input_t = std::conditional_t<
is_packed_type_v<ThisDataType>,
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
ThisDataType>;
} // namespace ck_tile

View File

@@ -4,6 +4,8 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <tuple>
#include <type_traits>
#include <stdint.h>
@@ -187,4 +189,19 @@ template <typename Tuple_, std::size_t Idx, typename DefaultType>
using tuple_element_or_default_t =
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
// Helper struct to determine if a type is packed (more than 1 element per byte)
template <typename T>
struct is_packed_type
{
static constexpr bool value = numeric_traits<T>::PackedSize > 1;
};
template <typename T>
static constexpr bool is_packed_type_v = is_packed_type<T>::value;
// Helper definition to take the largest sizes type
template <typename ADataType, typename BDataType>
using largest_type_t =
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
} // namespace ck_tile