mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times
* wip: fp4 scaffolding for abquant
* feat: add fp4 decoding-while-loading to abquant pipeline
* feat: add support for fp4 CPU verification in abquant
* chore: add time tracking to reference calculation
* feat: add a4w4 test for blockscale gemm
* feat: optimize reference calculation by preconverting values to AccType
* feat: add fp4 to fp8 look-up table
* fix: reference to wrong ComputeDataType field in QuantProblem
* feat: type utilities for determining MFMA compute types
* feat: packed fp4 for abquant weight preshuffle
* feat: add separate tests for a4w4 base case, padding and preshuffleB
* fix: fp4 conversion on gfx950 attempting to use non-supported method
* fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size
* chore: add fp4 preshuffleb mode to block scale example
* chore: sanity check for packed types being 1 byte
* chore: clarify tensor dimension indices with constants
* chore: replace traits check with specialized check for packed types
* style: some minor refactoring and cleanup
* fix: correct conversion table for FNUZ fp8
* chore: add fp4 instances to main abquant instances again
* chore: use same initialization branch for int4 and fp4
* chore: add missing initialization for fp4 in block scale gemm example
---------
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
[ROCm/composable_kernel commit: 6a6177a246]
This commit is contained in:
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper method to automatically determine compute type
|
||||
// Selects the largest type of the two. If both of them are packed data types, defaults to fp8.
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct auto_compute_type
|
||||
{
|
||||
using LargestInputType = largest_type_t<ADataType, BDataType>;
|
||||
|
||||
// Sanity check: there are no packed types larger than 1 byte yet, but if we add them
|
||||
// this logic should change
|
||||
static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));
|
||||
|
||||
using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
|
||||
};
|
||||
|
||||
// Helper method to determine compute type, defaulting an explicitly passed-in compute type
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
struct mixed_prec_compute_type
|
||||
{
|
||||
using type = std::conditional_t<std::is_void_v<ComputeDataType>,
|
||||
typename auto_compute_type<ADataType, BDataType>::type,
|
||||
ComputeDataType>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
using mixed_prec_compute_type_t =
|
||||
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
|
||||
|
||||
// Helper method to determine compute type, defaulting to input data type
|
||||
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
|
||||
// ComputeDataType is used.
|
||||
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
|
||||
using mixed_prec_compute_type_from_input_t = std::conditional_t<
|
||||
is_packed_type_v<ThisDataType>,
|
||||
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
|
||||
ThisDataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
@@ -187,4 +189,19 @@ template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
using tuple_element_or_default_t =
|
||||
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
|
||||
|
||||
// Helper struct to determine if a type is packed (more than 1 element per byte)
|
||||
template <typename T>
|
||||
struct is_packed_type
|
||||
{
|
||||
static constexpr bool value = numeric_traits<T>::PackedSize > 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_packed_type_v = is_packed_type<T>::value;
|
||||
|
||||
// Helper alias that yields whichever of the two types has the larger sizeof.
// When both types have the same size, ADataType is selected.
template <typename ADataType, typename BDataType>
using largest_type_t =
    std::conditional_t<(sizeof(BDataType) > sizeof(ADataType)), BDataType, ADataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user