MX GEMM - FP6 Example (#2419)

Adds support for MX FP6 data type in MX GEMM block pipeline version v1.
Provides an example of MX FP6 GEMM algorithm.

---------

Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: valarLip <340077269@qq.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>

[ROCm/composable_kernel commit: 054f85ab7c]
This commit is contained in:
Andriy Roshchenko
2025-07-07 10:33:26 -06:00
committed by GitHub
parent 6a1b27e411
commit 67545a9d22
18 changed files with 578 additions and 95 deletions

View File

@@ -14,7 +14,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
* Added support for Multiple D GEMM
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.

View File

@@ -10,6 +10,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp6)
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
@@ -55,3 +58,7 @@ set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
set(FP6_MXGEMM_OPTIONS)
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})

View File

@@ -245,6 +245,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
if(K % ck::packed_size_v<ADataType> != 0 || K % ck::packed_size_v<BDataType> != 0)
{
throw std::runtime_error("wrong! K must be multiple of packed size.");
};
// Hardcode scale layouts as per pipeline assumptions
// TODO: Allow user to specify scale layouts
using AScaleLayout = Row;
@@ -292,12 +297,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else if constexpr(ck::packed_size_v<ADataType> == 32)
return ck::type_convert<ADataType>(ck::float32_t(x));
else if constexpr(ck::packed_size_v<ADataType> == 16)
return ck::type_convert<ADataType>(ck::float16_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else if constexpr(ck::packed_size_v<BDataType> == 32)
return ck::type_convert<BDataType>(ck::float32_t(x));
else if constexpr(ck::packed_size_v<BDataType> == 16)
return ck::type_convert<BDataType>(ck::float16_t(x));
else
return ck::type_convert<BDataType>(x);
};
@@ -307,30 +320,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<ADataType>{a_data_element(0.5f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A = {0.5}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Init B = {2.0}" << std::endl;
std::cout << "Init B scale = {0.5}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
a_m_k.GenerateTensorDistr(
int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5]
b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5]
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
a_m_k.GenerateTensorDistr(
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
using ADataType = ck::f6x16_pk_t;
using BDataType = ck::f6x16_pk_t;
using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v<ADataType>; // K dimension size per block
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XDataType, // AScaleDataType
BDataType, // BDataType
XDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Number of threads per block
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
1, // AK1 number of elements to read at a time when transferring from global memory to LDS
1, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -556,6 +556,64 @@ struct Tensor
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_)))})};
else if constexpr(ck::is_same_v<T, ck::f6x32_pk_t> ||
ck::is_same_v<T, ck::bf6x32_pk_t>)
{
return ck::type_convert<T>(
ck::float32_t{ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_)))});
}
else if constexpr(ck::is_same_v<T, ck::f6x16_pk_t> ||
ck::is_same_v<T, ck::bf6x16_pk_t>)
{
return ck::type_convert<T>(
ck::float16_t{ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_)))});
}
else
static_assert(false, "Unsupported packed size for T");
};

View File

@@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
// store rows/cols into thread registers in chunks of 16 for FP8
// e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
// or in chunks of 32 / APackedSize for FP6/FP4
static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now");
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;

View File

@@ -54,6 +54,8 @@ namespace device {
*
* Conditions for achieving computational load balancing on different hardware platforms can vary.
*
* \tparam KPerBlock is the number of elements in K dimension that each block processes (multiply with packed_size_v to get the actual KPerBlock)
*
* Serialized version of the algorithm:
* \code
* // E = A * B + C
@@ -117,7 +119,7 @@ template <typename ALayout,
index_t BlockSize, // Thread block size
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t KPerBlock, // multiply with packed_size_v to get the actual KPerBlock
index_t AK1,
index_t BK1,
index_t MPerXDL,

View File

@@ -419,6 +419,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MPadding)),
"f4x2_pk_t does not support K padding");
static_assert(!((is_same_v<remove_cvref_t<ADataType>, f6x16_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, bf6x16_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, f6x32_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, bf6x32_pk_t>)&&GemmSpec !=
GemmSpecialization::Default),
"Packed F6 types do not support padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)

View File

@@ -889,7 +889,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
@@ -1224,6 +1223,27 @@ struct MfmaSelector
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
{
@@ -1405,8 +1425,7 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B

View File

@@ -1037,6 +1037,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
#endif
}
template <class FloatC>
__device__ static void Run(const f6x16x2_t& reg_a,
const int32_t scale_a,
const f6x16x2_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
using arg_type = int32x8_t;
arg_type arg_a{
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
0,
0};
arg_type arg_b{
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
0,
0};
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_a,
arg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
OpselA, // OPSEL
scale_a,
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf6x32_t& reg_a,
const int32_t scale_a,

View File

@@ -67,27 +67,42 @@ struct f6_pk_t
{
using element_type = uint32_t; // element storage fundamental type
static constexpr index_t packed_size = pk_size;
static constexpr index_t num_bits_elem = 6;
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
static constexpr index_t packed_size = pk_size; // 16 or 32 for now
static constexpr index_t num_bits_elem = 6; // specialized for 6-bit data
// XXX: CHAR_BIT is not defined in HIPRTC, so we must use 8
static constexpr index_t num_bits_vec_elem =
sizeof(element_type) * 8; // 32-bit uint for storage
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
"Packed elements must fit exactly into the element storage.");
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
static constexpr index_t vector_size =
(packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
storage_type data; // packed data
using storage_type = element_type __attribute__((ext_vector_type(vector_size)));
storage_type data_{storage_type(0)}; // packed data
using type = f6_pk_t<BitType, packed_size>;
__host__ __device__ constexpr f6_pk_t() : data{} {}
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
__host__ __device__ constexpr f6_pk_t() {}
__host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}
// Initialize from a vector type with the same size as packed_size
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
__host__ __device__ f6_pk_t(const T& v) : data{}
__host__ __device__ f6_pk_t(const T& v)
{
static_for<0, packed_size, 1>{}(
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
}
// Broadcast single initialization value to all packed elements
__host__ __device__ f6_pk_t(const int8_t v)
: f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v))
{
// TODO: consider removing initialization similar to vector_type<T, 256>
}
template <typename T>
__host__ __device__ void pack(const T x, const index_t i)
{
@@ -99,18 +114,18 @@ struct f6_pk_t
const int arr_index = bit_pos / num_bits_vec_elem;
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = data.data_[arr_index];
uint32_t old_value = data_[arr_index];
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
data.data_[arr_index] = old_value;
data_[arr_index] = old_value;
// if it crosses into the next block, shift the remainder
if(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = data.data_[arr_index + 1];
uint32_t next_value = data_[arr_index + 1];
next_value |= (bits >> (num_bits_elem - overhang));
data.data_[arr_index + 1] = next_value;
data_[arr_index + 1] = next_value;
}
}
@@ -121,17 +136,33 @@ struct f6_pk_t
const int bit_offset = bit_pos % num_bits_vec_elem;
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
uint32_t bits = pk.data_[arr_idx] >> bit_offset;
if(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
}
return static_cast<BitType>(bits & 0x3F);
}
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
// Compare operator
__host__ __device__ friend bool operator==(const f6_pk_t& lhs, const f6_pk_t& rhs)
{
#pragma unroll
for(index_t i = 0; i < vector_size; ++i)
{
if(lhs.data_[i] != rhs.data_[i])
return false;
}
return true;
}
__host__ __device__ friend bool operator!=(const f6_pk_t& lhs, const f6_pk_t& rhs)
{
return !(lhs == rhs);
}
};
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
@@ -296,6 +327,34 @@ struct scalar_type<f4x2_pk_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f6x32_pk_t>
{
using type = f6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf6x32_pk_t>
{
using type = bf6x32_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f6x16_pk_t>
{
using type = f6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf6x16_pk_t>
{
using type = bf6x16_pk_t::storage_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bool>
{

View File

@@ -1438,14 +1438,16 @@ struct non_native_vector_base<
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
struct non_native_vector_base<
T,
N,
ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 16 || sizeof(T) == 24 || sizeof(T) == 32>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t);
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
@@ -1457,29 +1459,29 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
// Broadcast single value to vector
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{}
{
// TODO: consider removing initialization similar to vector_type<T, 256>
ck::static_for<0, N, 1>{}([&](auto i) {
data_.dxN(i) = a; // broadcast value to all elements
});
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
@@ -1488,7 +1490,31 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
}
else
{
return data_.dTxN; // XXX this should cause an error
return err; // XXX this should cause an error
}
}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same_v<X, data_t> || is_same_v<X, data_v> || is_same_v<X, T>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else
{
return err;
}
}
};
@@ -1504,8 +1530,10 @@ struct scalar_type<non_native_vector_base<
};
template <typename T, index_t N>
struct scalar_type<
non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
struct scalar_type<non_native_vector_base<
T,
N,
ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 16 || sizeof(T) == 24 || sizeof(T) == 32>>>
{
using type = typename non_native_vector_base<T, N>::element_t;
static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
@@ -2221,8 +2249,9 @@ using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x16x2_t = typename vector_type<f6x16_pk_t, 2>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;

View File

@@ -34,6 +34,10 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
// XXX: PackedSize semantics for pk_i4_t is different from the other packed types.
// Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while
// objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t must
// be divided by 2 to correctly represent the number of addressable elements.
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;

View File

@@ -501,8 +501,8 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector.template AsType<f6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<f6_t>(scale, x);
@@ -522,7 +522,8 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m
f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x.template AsType<f6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
#else
union
{
@@ -567,8 +568,8 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
float float_array[32];
} out{};
out.float_vector =
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector.template AsType<bf6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(scale, x);
@@ -588,7 +589,8 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8
bf6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x.template AsType<bf6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
#else
union
{

View File

@@ -1734,7 +1734,7 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
out.f6_vector = f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale)};
return out.f6_array[0];
#else
@@ -1757,7 +1757,7 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
return f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale)};
#else
union
{
@@ -1765,17 +1765,15 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
float float_array[32];
} in{x};
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} out{};
using array_type = uint8_t __attribute__((ext_vector_type(32)));
array_type uint8_array;
// collect the 6-bit values into an array
ck::static_for<0, 32, 1>{}([&](auto i) {
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
uint8_array[static_cast<index_t>(i)] =
utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
});
return out.f6_vector;
return f6x32_t{f6x32_pk_t{uint8_array}};
#endif
}
@@ -1807,7 +1805,8 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
f6_t f6_array[32];
} out{};
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
out.f6_vector =
f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale)};
return out.f6_array[0];
#else
@@ -1837,7 +1836,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
return f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale)};
#else
constexpr int seed = 1254739;
union
@@ -1852,6 +1851,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
uint32_t rng =
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
#endif
union
{
float32_t float_vector;
@@ -1914,6 +1914,43 @@ inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
#endif
}
template <>
inline __host__ __device__ f6x32_pk_t type_convert<f6x32_pk_t, float32_t>(float32_t x)
{
return static_cast<f6x32_pk_t>(type_convert<f6x32_t>(x));
}
template <>
inline __host__ __device__ f6x16_t type_convert<f6x16_t, float16_t>(float16_t x)
{
union
{
float16_t v16x2[2];
float32_t v32;
} in{{x, x}};
union
{
f6x32_t v32;
f6x16_t v16x2[2];
} out{};
#if CK_USE_SR_F6_CONVERSION
out.v32 = f6_convert_sr(in.v32);
#else
out.v32 = f6_convert_rne(in.v32);
#endif
return out.v16x2[0];
}
template <>
inline __host__ __device__ f6x16_pk_t type_convert<f6x16_pk_t, float16_t>(float16_t x)
{
return static_cast<f6x16_pk_t>(type_convert<f6x16_t>(x));
}
/**
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
* float.
@@ -1929,9 +1966,9 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
#if defined(__gfx950__)
union
{
f6x32_t f6_vector;
f6_t f6_array[32];
} in{x};
f6x32_t f6_vector;
} in{{x}};
union
{
@@ -1940,7 +1977,8 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
in.f6_vector.template AsType<f6x32_t::data_t>()[Number<0>{}],
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
@@ -1948,8 +1986,8 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
}
/**
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
* (f6x32_t) to vector of 32 floats.
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float
* types (f6x32_t) to vector of 32 floats.
*
* Interprets an f6_t values as floats using the default scale factor of 1.
*
@@ -1961,7 +1999,8 @@ inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
x.template AsType<f6x32_t::data_t>()[Number<0>{}],
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{
@@ -1984,6 +2023,31 @@ inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
#endif
}
template <>
inline __host__ __device__ float16_t type_convert<float16_t, f6x16_t>(f6x16_t x)
{
union
{
f6x16_t v16x2[2];
f6x32_t v32;
} in{{x, x}};
union
{
float16_t v16x2[2];
float32_t v32;
} out{};
out.v32 = type_convert<float32_t>(in.v32);
return out.v16x2[0];
}
template <>
inline __host__ __device__ float16_t type_convert<float16_t, f6x16_pk_t>(f6x16_pk_t x)
{
return type_convert<float16_t>(static_cast<f6x16_t>(x));
}
/**
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
*
@@ -2006,7 +2070,7 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
out.bf6_vector = bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale)};
return out.bf6_array[0];
#else
@@ -2030,7 +2094,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
#if defined(__gfx950__)
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
return bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale)};
#else
union
{
@@ -2081,7 +2145,8 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
bf6_t bf6_array[32];
} out{};
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
out.bf6_vector =
bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale)};
return out.bf6_array[0];
#else
@@ -2113,7 +2178,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
// use HW clock for stochastic input multiply by incremented thread id
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
(get_thread_global_1d_id() + 1));
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
return bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale)};
#else
constexpr int seed = 1254739;
union
@@ -2186,6 +2251,12 @@ inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t
#endif
}
template <>
inline __host__ __device__ bf6x32_pk_t type_convert<bf6x32_pk_t, float32_t>(float32_t x)
{
return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
}
/**
* @brief Specializes the type conversion template for converting a bf6_t value to float.
*
@@ -2201,9 +2272,9 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
#if defined(__gfx950__)
union
{
bf6x32_t bf6_vector;
bf6_t bf6_array[32];
} in{x};
bf6x32_t bf6_vector;
} in{{x}};
union
{
@@ -2212,7 +2283,8 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
} out{};
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
in.bf6_vector.template AsType<bf6x32_t::data_t>()[Number<0>{}],
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
return out.float_array[0];
#else
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
@@ -2234,7 +2306,8 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
{
#if defined(__gfx950__)
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
x.template AsType<bf6x32_t::data_t>()[Number<0>{}],
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
#else
union
{

View File

@@ -53,6 +53,7 @@ if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_fp6 test_fp6.cpp)
if(result EQUAL 0)
target_compile_options(test_fp6 PRIVATE -mavx512f)
target_link_libraries(test_fp6 PRIVATE utility)
endif()
add_dependencies(test_mx_data_types test_fp6)

View File

@@ -228,8 +228,8 @@ TEST(BF6, ScaledConvertFP32Stochastic)
TEST(BF6, TestSize)
{
ASSERT_EQ(1, sizeof(bf6_t));
ASSERT_EQ(12, sizeof(bf6x16_pk_t));
ASSERT_EQ(24, sizeof(bf6x32_pk_t));
ASSERT_EQ(16, sizeof(bf6x16_pk_t));
ASSERT_EQ(32, sizeof(bf6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
@@ -238,8 +238,8 @@ TEST(BF6, TestSize)
TEST(BF6, TestAlignment)
{
ASSERT_EQ(1, alignof(bf6_t));
ASSERT_EQ(4, alignof(bf6x16_pk_t));
ASSERT_EQ(4, alignof(bf6x32_pk_t));
ASSERT_EQ(16, alignof(bf6x16_pk_t));
ASSERT_EQ(32, alignof(bf6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));

View File

@@ -6,6 +6,7 @@
#include "ck/utility/type_convert.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scaled_type_convert.hpp"
#include "ck/library/utility/device_memory.hpp"
using ck::e8m0_bexp_t;
using ck::f6_convert_rne;
@@ -227,8 +228,8 @@ TEST(FP6, ScaledConvertFP32Stochastic)
TEST(FP6, TestSize)
{
ASSERT_EQ(1, sizeof(f6_t));
ASSERT_EQ(12, sizeof(f6x16_pk_t));
ASSERT_EQ(24, sizeof(f6x32_pk_t));
ASSERT_EQ(16, sizeof(f6x16_pk_t));
ASSERT_EQ(32, sizeof(f6x32_pk_t));
ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
@@ -237,8 +238,8 @@ TEST(FP6, TestSize)
TEST(FP6, TestAlignment)
{
ASSERT_EQ(1, alignof(f6_t));
ASSERT_EQ(4, alignof(f6x16_pk_t));
ASSERT_EQ(4, alignof(f6x32_pk_t));
ASSERT_EQ(16, alignof(f6x16_pk_t));
ASSERT_EQ(32, alignof(f6x32_pk_t));
ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
@@ -292,6 +293,60 @@ TEST(FP6, TestAsType16x1)
});
}
__global__ void test_f6_convert_rne(float* p_test, uint64_t* p_completed)
{
constexpr int N = 32;
if(p_completed == nullptr)
{
return;
}
uint64_t& i = *p_completed;
i = 0;
if(p_test == nullptr)
{
return;
}
ck::float32_t float32_in(1.0f);
ck::float32_t float32_out{};
auto f6x32_vec = f6_convert_rne(float32_in);
float32_out = type_convert<ck::float32_t>(f6x32_vec);
ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32_out[static_cast<int>(ii)]; });
i = N;
}
TEST(MXFP6, DeviceF6ConvertRNE)
{
constexpr int N = 32;
std::vector<float> out(N, -1.0f);
DeviceMem device_out(N * sizeof(float));
DeviceMem device_completed(sizeof(uint64_t));
device_out.SetValue(-21.0f);
device_completed.SetValue(-21.0f);
test_f6_convert_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()),
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
uint64_t completed = 0;
device_completed.FromDevice(&completed);
device_out.FromDevice(out.data());
EXPECT_EQ(N, completed);
ck::static_for<0, N, 1>{}(
[&](auto ii) { EXPECT_EQ(out[static_cast<int>(ii)], 1.0f) << "ii: " << ii << std::endl; });
auto f6x32_vec_tc = ck::type_convert<f6x32_pk_t>(ck::float32_t(1.0f));
auto f6x32_vec_cnstr = f6x32_pk_t(0x08);
EXPECT_EQ(f6x32_vec_tc, f6x32_vec_cnstr);
}
// test vector of 2 f6x16_pk_t, contains 32 f6_t
TEST(FP6, TestAsType16x2)
{