mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant (#3603)
* chore: split block scale example instances in more separate files to speed up compile times * wip: fp4 scaffolding for abquant * feat: add fp4 decoding-while-loading to abquant pipeline * feat: add support for fp4 CPU verification in abquant * chore: add time tracking to reference calculation * feat: add a4w4 test for blockscale gemm * feat: optimize reference calculation by preconverting values to AccType * feat: add fp4 to fp8 look-up table * fix: reference to wrong ComputeDataType field in QuantProblem * feat: type utilities for determining MFMA compute types * feat: packed fp4 for abquant weight preshuffle * feat: add separate tests for a4w4 base case, padding and preshuffleB * fix: fp4 conversion on gfx950 attempting to use non-supported method * fix: test case was using quant group sizes which don't work on gfx950 due to larger mfma tile size * chore: add fp4 preshuffleb mode to block scale example * chore: sanity check for packed types being 1 byte * chore: clarify tensor dimension indices with constants * chore: replace traits check with specialized check for packed types * style: some minor refactoring and cleanup * fix: correct conversion table for FNUZ fp8 * chore: add fp4 instances to main abquant instances again * chore: use same initialization branch for int4 and fp4 * chore: add missing initialization for fp4 in block scale gemm example --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -1544,7 +1544,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, pk_fp4_raw_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_fp4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
|
||||
@@ -1414,7 +1414,7 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32))),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -23,6 +24,12 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
|
||||
#else
|
||||
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
|
||||
#endif
|
||||
|
||||
// Helpers: constexpr-safe access to elements of ext_vector_type(2)
|
||||
// Some compilers don't allow operator[] in constant expressions for vector types.
|
||||
// We use bit_cast to a trivially copyable representation to extract lanes.
|
||||
@@ -98,6 +105,8 @@ struct pk_float4_e2m1_t
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t to_fp8(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t to_fp8x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
@@ -105,6 +114,8 @@ struct pk_float4_e2m1_t
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp8_t() const { return to_fp8(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp8x2_t() const { return to_fp8x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number<I>) const
|
||||
@@ -145,6 +156,49 @@ struct pk_float4_e2m1_t
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_OCP_FP8
|
||||
// FP8 EM4E3 (OCP) representation
|
||||
static constexpr fp8_t e2m1_to_fp8_table[16] = {
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // 0
|
||||
fp8_t(static_cast<uint8_t>(0x30)), // 0.5
|
||||
fp8_t(static_cast<uint8_t>(0x38)), // 1
|
||||
fp8_t(static_cast<uint8_t>(0x3C)), // 1.5
|
||||
fp8_t(static_cast<uint8_t>(0x40)), // 2
|
||||
fp8_t(static_cast<uint8_t>(0x44)), // 3
|
||||
fp8_t(static_cast<uint8_t>(0x48)), // 4
|
||||
fp8_t(static_cast<uint8_t>(0x4C)), // 6
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // -0
|
||||
fp8_t(static_cast<uint8_t>(0xB0)), // -0.5
|
||||
fp8_t(static_cast<uint8_t>(0xB8)), // -1
|
||||
fp8_t(static_cast<uint8_t>(0xBC)), // -1.5
|
||||
fp8_t(static_cast<uint8_t>(0xC0)), // -2
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -3
|
||||
fp8_t(static_cast<uint8_t>(0xC8)), // -4
|
||||
fp8_t(static_cast<uint8_t>(0xCC)) // -6
|
||||
};
|
||||
#else // CK_TILE_USE_FNUZ_FP8
|
||||
// FP8 E4M3 FNUZ
|
||||
static constexpr fp8_t e2m1_to_fp8_table[16] = {
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // 0
|
||||
fp8_t(static_cast<uint8_t>(0x38)), // 0.5
|
||||
fp8_t(static_cast<uint8_t>(0x40)), // 1
|
||||
fp8_t(static_cast<uint8_t>(0x44)), // 1.5
|
||||
fp8_t(static_cast<uint8_t>(0x48)), // 2
|
||||
fp8_t(static_cast<uint8_t>(0x4C)), // 3
|
||||
fp8_t(static_cast<uint8_t>(0x50)), // 4
|
||||
fp8_t(static_cast<uint8_t>(0x54)), // 6
|
||||
fp8_t(static_cast<uint8_t>(0x00)), // -0
|
||||
fp8_t(static_cast<uint8_t>(0xB8)), // -0.5
|
||||
fp8_t(static_cast<uint8_t>(0xC0)), // -1
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -1.5
|
||||
fp8_t(static_cast<uint8_t>(0xC4)), // -2
|
||||
fp8_t(static_cast<uint8_t>(0xCC)), // -3
|
||||
fp8_t(static_cast<uint8_t>(0xD0)), // -4
|
||||
fp8_t(static_cast<uint8_t>(0xD4)) // -6
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -408,6 +462,27 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
|
||||
{
|
||||
// NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
|
||||
// would be better than the naive implementation below
|
||||
// #if CK_TILE_FP4_CVT_DEVICE
|
||||
// return impl::_from_f4<fp8_t>(data, scale);
|
||||
// #else
|
||||
return fp8_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale))};
|
||||
// #endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
|
||||
{
|
||||
// NOTE: No specialized fp4 to fp8 instructions are available. Unsure whether fp4 to fp16 to fp8
|
||||
// would be better than the naive implementation below
|
||||
// #if CK_TILE_FP4_CVT_DEVICE
|
||||
// return impl::_from_f4<fp8x2_t>(data, scale);
|
||||
// #else
|
||||
return fp8x2_t{type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<0>{}), scale)),
|
||||
type_convert<fp8_t>(convert_to_float<pk_fp4_t>(_unpack(number<1>{}), scale))};
|
||||
// #endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
@@ -415,7 +490,8 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale};
|
||||
return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale,
|
||||
e2m1_to_fp32_table[_unpack(number<1>{})] * scale};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
@@ -428,6 +504,16 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[_unpack(number<1>{})]) *
|
||||
scale)};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t pk_fp4_t::to_fp8(float scale) const
|
||||
{
|
||||
return type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp8x2_t pk_fp4_t::to_fp8x2(float scale) const
|
||||
{
|
||||
return fp8x2_t{
|
||||
type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<0>{})]) * scale),
|
||||
type_convert<fp8_t>(type_convert<float>(e2m1_to_fp8_table[_unpack(number<1>{})]) * scale)};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
@@ -23,6 +24,11 @@ struct pk_int4_t
|
||||
type data;
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
|
||||
|
||||
// NOTE: added for interface compatibility with pk_fp4_t
|
||||
// Other data types could be added for greater similarity
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
};
|
||||
|
||||
// limits
|
||||
@@ -186,4 +192,9 @@ CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_int4_t::to_fp32x2() const
|
||||
{
|
||||
return pk_int4_t_to_fp32x2_t(*this);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
|
||||
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
54
include/ck_tile/core/utility/mixed_prec_compute_type.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper method to automatically determine compute type
|
||||
// Selects the largest type of the two. If both of them are packed data types, defaults to fp8.
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct auto_compute_type
|
||||
{
|
||||
using LargestInputType = largest_type_t<ADataType, BDataType>;
|
||||
|
||||
// Sanity check: there are no packed types larger than 1 byte yet, but if we add them
|
||||
// this logic should change
|
||||
static_assert(!is_packed_type_v<LargestInputType> || sizeof(LargestInputType) == sizeof(fp8_t));
|
||||
|
||||
using type = std::conditional_t<is_packed_type_v<LargestInputType>, fp8_t, LargestInputType>;
|
||||
};
|
||||
|
||||
// Helper method to determine compute type, defaulting an explicitly passed-in compute type
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
struct mixed_prec_compute_type
|
||||
{
|
||||
using type = std::conditional_t<std::is_void_v<ComputeDataType>,
|
||||
typename auto_compute_type<ADataType, BDataType>::type,
|
||||
ComputeDataType>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename ComputeDataType, typename ADataType, typename BDataType>
|
||||
using mixed_prec_compute_type_t =
|
||||
typename detail::mixed_prec_compute_type<ComputeDataType, ADataType, BDataType>::type;
|
||||
|
||||
// Helper method to determine compute type, defaulting to input data type
|
||||
// If "ThisDataType" is packed (4-bit), will default to "OtherDataType". If both are packed,
|
||||
// ComputeDataType is used.
|
||||
template <typename ThisDataType, typename OtherDataType, typename ComputeDataType>
|
||||
using mixed_prec_compute_type_from_input_t = std::conditional_t<
|
||||
is_packed_type_v<ThisDataType>,
|
||||
std::conditional_t<is_packed_type_v<OtherDataType>, ComputeDataType, OtherDataType>,
|
||||
ThisDataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
@@ -187,4 +189,19 @@ template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
using tuple_element_or_default_t =
|
||||
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
|
||||
|
||||
// Helper struct to determine if a type is packed (more than 1 element per byte)
|
||||
template <typename T>
|
||||
struct is_packed_type
|
||||
{
|
||||
static constexpr bool value = numeric_traits<T>::PackedSize > 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_packed_type_v = is_packed_type<T>::value;
|
||||
|
||||
// Helper definition to take the largest sizes type
|
||||
template <typename ADataType, typename BDataType>
|
||||
using largest_type_t =
|
||||
std::conditional_t<sizeof(ADataType) >= sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user