mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times
* wip: fp4 scaffolding for abquant
* feat: add fp4 decoding-while-loading to abquant pipeline
* feat: add support for fp4 CPU verification in abquant
* chore: add time tracking to reference calculation
* feat: add a4w4 test for blockscale gemm
* feat: optimize reference calculation by preconverting values to AccType
* feat: add fp4 to fp8 look-up table
* fix: reference to wrong ComputeDataType field in QuantProblem
* feat: type utilities for determining MFMA compute types
* feat: packed fp4 for abquant weight preshuffle
* feat: add separate tests for a4w4 base case, padding and preshuffleB
* fix: fp4 conversion on gfx950 attempting to use non-supported method
* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size
* chore: add fp4 preshuffleb mode to block scale example
* chore: sanity check for packed types being 1 byte
* chore: clarify tensor dimension indices with constants
* chore: replace traits check with specialized check for packed types
* style: some minor refactoring and cleanup
* fix: correct conversion table for FNUZ fp8
* chore: add fp4 instances to main abquant instances again
* chore: use same initialization branch for int4 and fp4
* chore: add missing initialization for fp4 in block scale gemm example
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -4,11 +4,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DstDataType, index_t UnaryOpSize>
|
||||
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
|
||||
struct InterleavedPKTypeLoader
|
||||
{
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
@@ -21,10 +22,15 @@ struct InterleavedPKTypeLoader
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
|
||||
using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize)));
|
||||
// NOTE: we rely on types packing neatly here
|
||||
using RawSrcType = typename SrcDataType::type;
|
||||
constexpr auto PackedSize = numeric_traits<SrcDataType>::PackedSize;
|
||||
|
||||
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
|
||||
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -37,10 +43,11 @@ template <typename SrcDataType,
|
||||
typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
|
||||
if constexpr(is_packed_type_v<SrcDataType>)
|
||||
{
|
||||
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t");
|
||||
InterleavedPKTypeLoader<DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
|
||||
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
|
||||
dst, src);
|
||||
}
|
||||
else if constexpr(LoadTranspose)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user