[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)

* chore: split block scale example instances in more separate files to speed up compile times

* wip: fp4 scaffolding for abquant

* feat: add fp4 decoding-while-loading to abquant pipeline

* feat: add support for fp4 CPU verification in abquant

* chore: add time tracking to reference calculation

* feat: add a4w4 test for blockscale gemm

* feat: optimize reference calculation by preconverting values to AccType

* feat: add fp4 to fp8 look-up table

* fix: reference to wrong ComputeDataType field in QuantProblem

* feat: type utilities for determining MFMA compute types

* feat: packed fp4 for abquant weight preshuffle

* feat: add separate tests for a4w4 base case, padding and preshuffleB

* fix: fp4 conversion on gfx950 attempting to use non-supported method

* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size

* chore: add fp4 preshuffleb mode to block scale example

* chore: sanity check for packed types being 1 byte

* chore: clarify tensor dimension indices with constants

* chore: replace traits check with specialized check for packed types

* style: some minor refactoring and cleanup

* fix: correct conversion table for FNUZ fp8

* chore: add fp4 instances to main abquant instances again

* chore: use same initialization branch for int4 and fp4

* chore: add missing initialization for fp4 in block scale gemm example

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Erwin Terpstra
2026-01-30 12:40:50 +01:00
committed by GitHub
parent 565fea2645
commit 6a6177a246
28 changed files with 642 additions and 175 deletions

View File

@@ -4,11 +4,12 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
template <typename DstDataType, index_t UnaryOpSize>
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
struct InterleavedPKTypeLoader
{
template <typename WarpWindow, typename WarpTile>
@@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
const auto in_dstr_tensors = load_tile(warp_window);
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
// NOTE: we rely on types packing neatly here
using RawSrcType = typename SrcDataType::type;
constexpr auto PackedSize = numeric_traits<SrcDataType>::PackedSize;
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
});
}
};
@@ -37,10 +43,11 @@ template <typename SrcDataType,
typename WarpWindow>
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
if constexpr(is_packed_type_v<SrcDataType>)
{
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
dst, src);
}
else if constexpr(LoadTranspose)
{