mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times * wip: fp4 scaffolding for abquant * feat: add fp4 decoding-while-loading to abquant pipeline * feat: add support for fp4 CPU verification in abquant * chore: add time tracking to reference calculation * feat: add a4w4 test for blockscale gemm * feat: optimize reference calculation by preconverting values to AccType * feat: add fp4 to fp8 look-up table * fix: reference to wrong ComputeDataType field in QuantProblem * feat: type utilities for determining MFMA compute types * feat: packed fp4 for abquant weight preshuffle * feat: add separate tests for a4w4 base case, padding and preshuffleB * fix: fp4 conversion on gfx950 attempting to use non-supported method * fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size * chore: add fp4 preshuffleb mode to block scale example * chore: sanity check for packed types being 1 byte * chore: clarify tensor dimension indices with constants * chore: replace traits check with specialized check for packed types * style: some minor refactoring and cleanup * fix: correct conversion table for FNUZ fp8 * chore: add fp4 instances to main abquant instances again * chore: use same initialization branch for int4 and fp4 * chore: add missing initialization for fp4 in block scale gemm example --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
@@ -255,17 +256,26 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
// Determine compute types to use
|
||||
// This logic defaults to A/B DataType, but if one of them is packed falls back to the other
|
||||
// If both are packed, it falls back to the explicitly defined ComputeDataType in the
|
||||
// problem It might be a good idea to use ComputeDataType anyway, but that would break how
|
||||
// this behaviour used to work
|
||||
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::ComputeDataType>;
|
||||
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::ComputeDataType>;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
constexpr index_t KLaneBytes =
|
||||
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
|
||||
Reference in New Issue
Block a user