From 2325a9fe3a033dfe5e68206a742211ddb7e804b9 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:33:26 -0600 Subject: [PATCH] MX GEMM - FP6 Example (#2419) Adds support for the MX FP6 data type in the MX GEMM block pipeline, version v1. Provides an example of the MX FP6 GEMM algorithm. --------- Co-authored-by: OscarXu Co-authored-by: aska-0096 Co-authored-by: mtgu0705 Co-authored-by: Your Name Co-authored-by: lalala-sh Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: Ding, Yi Co-authored-by: feifei14119 Co-authored-by: Lin, Qun Co-authored-by: joye [ROCm/composable_kernel commit: 054f85ab7c0fa07a90968e834899ec415af8b713] --- CHANGELOG.md | 2 +- example/67_gemm_microscaling/CMakeLists.txt | 7 + .../67_gemm_microscaling/gemm_mx_common.hpp | 38 ++++-- example/67_gemm_microscaling/gemm_mx_fp6.cpp | 99 ++++++++++++++ include/ck/library/utility/host_tensor.hpp | 58 ++++++++ ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 9 +- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 4 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 6 + .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 25 +++- include/ck/utility/amd_xdlops.hpp | 48 +++++++ include/ck/utility/data_type.hpp | 91 ++++++++++--- include/ck/utility/dtype_vector.hpp | 71 +++++++--- include/ck/utility/dynamic_buffer.hpp | 4 + include/ck/utility/scaled_type_convert.hpp | 14 +- include/ck/utility/type_convert.hpp | 125 ++++++++++++++---- test/data_type/CMakeLists.txt | 1 + test/data_type/test_bf6.cpp | 8 +- test/data_type/test_fp6.cpp | 63 ++++++++- 18 files changed, 578 insertions(+), 95 deletions(-) create mode 100644 example/67_gemm_microscaling/gemm_mx_fp6.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f04935b8d..86a426e321 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added support for Multiple D GEMM -* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types +* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels.
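For context on what the patch implements: with microscaling (MX), every ScaleBlockSize = 32 consecutive elements along K share one e8m0_bexp_t scale, a biased power-of-two exponent whose value is 2^(e - 127); the effective matrix entry is the decoded FP6 (E2M3) value multiplied by its block's scale. The following is a minimal host-side sketch of that reference math only, not the device kernel; the decode_fp6_e2m3 and reference_mx_gemm helpers, the row-major scale layouts, and the scale_block parameter are assumptions made for the illustration.

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical decoder for one FP6 E2M3 code (1 sign, 2 exponent, 3 mantissa
// bits, exponent bias 1); stands in for the library's f6_t -> float conversion.
static float decode_fp6_e2m3(uint8_t bits)
{
    const int sign = (bits >> 5) & 0x1;
    const int exp  = (bits >> 3) & 0x3;
    const int man  = bits & 0x7;
    // exp == 0 encodes subnormals (no implicit leading 1)
    const float mag = (exp == 0) ? std::ldexp(man / 8.0f, 0)
                                 : std::ldexp(1.0f + man / 8.0f, exp - 1);
    return sign ? -mag : mag;
}

// Reference MX GEMM: C[m][n] = sum_k (a[m][k] * sa) * (b[k][n] * sb), where sa
// and sb are the e8m0 scales of the 32-wide K-blocks containing element k.
// Row-major storage is assumed for all operands here.
static void reference_mx_gemm(const std::vector<uint8_t>& a,       // M x K FP6 codes
                              const std::vector<uint8_t>& a_scale, // M x (K/32) e8m0
                              const std::vector<uint8_t>& b,       // K x N FP6 codes
                              const std::vector<uint8_t>& b_scale, // (K/32) x N e8m0
                              std::vector<float>& c,               // M x N
                              int M, int N, int K, int scale_block = 32)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                // e8m0: value = 2^(stored_exponent - 127)
                const float sa =
                    std::ldexp(1.0f, a_scale[m * (K / scale_block) + k / scale_block] - 127);
                const float sb = std::ldexp(1.0f, b_scale[(k / scale_block) * N + n] - 127);
                acc += decode_fp6_e2m3(a[m * K + k]) * sa * decode_fp6_e2m3(b[k * N + n]) * sb;
            }
            c[m * N + n] = acc;
        }
}

With the case-0 debug initialization introduced below (A = 0.5, A scale = 2.0, B = 2.0, B scale = 0.5), this reference yields C = K for every output element, which is exactly the "Expect C = {K}" message printed by the example.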
diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 07315d4aa5..35c5d18d50 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -10,6 +10,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) # add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) # add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) +add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp6) + add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) @@ -55,3 +58,7 @@ set(FP8_MXGEMM_OPTIONS) list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) + +set(FP6_MXGEMM_OPTIONS) +list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) +example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 1f01e1c7be..6ce10817ff 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -245,6 +245,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; + if(K % ck::packed_size_v<ADataType> != 0 || K % ck::packed_size_v<BDataType> != 0) + { + throw std::runtime_error("wrong! K must be multiple of packed size."); + }; + // Hardcode scale layouts as per pipeline assumptions // TODO: Allow user to specify scale layouts using AScaleLayout = Row; @@ -292,12 +297,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto a_data_element = [](float x) { if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>) return ck::type_convert<ADataType>(ck::float2_t(x)); + else if constexpr(ck::packed_size_v<ADataType> == 32) + return ck::type_convert<ADataType>(ck::float32_t(x)); + else if constexpr(ck::packed_size_v<ADataType> == 16) + return ck::type_convert<ADataType>(ck::float16_t(x)); else return ck::type_convert<ADataType>(x); }; auto b_data_element = [](float x) { if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>) return ck::type_convert<BDataType>(ck::float2_t(x)); + else if constexpr(ck::packed_size_v<BDataType> == 32) + return ck::type_convert<BDataType>(ck::float32_t(x)); + else if constexpr(ck::packed_size_v<BDataType> == 16) + return ck::type_convert<BDataType>(ck::float16_t(x)); else return ck::type_convert<BDataType>(x); }; @@ -307,30 +320,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c switch(config.init_method) { case 0: // Initializations for development and debugging - ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); + + ck::utils::FillConstant{a_data_element(0.5f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); + if(config.verbosity > 0) { - std::cout << "Init A = {1}" << std::endl; std::cout << "Init A scale = {2.0}" << std::endl; - std::cout << "Init B = {0.5}" << std::endl; - std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Init A = {0.5}" << std::endl; std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {2.0}" << std::endl; + std::cout << "Init B scale = {0.5}" << std::endl; std::cout << "Expect C = {K}" << std::endl; } break; case 1: - a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] - b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5] + a_m_k.GenerateTensorDistr( + int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5] + b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5] static_assert(ck::is_same_v); - a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} break; case 2: - a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0}); + a_m_k.GenerateTensorDistr( + float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2] a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); diff --git a/example/67_gemm_microscaling/gemm_mx_fp6.cpp b/example/67_gemm_microscaling/gemm_mx_fp6.cpp new file mode 100644 index 0000000000..615980082d --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp6.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f6x16_pk_t; +using BDataType = ck::f6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v<ADataType>; // K dimension size per block + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout BLayout, // BLayout CLayout, // CLayout ADataType, // ADataType XDataType, // AScaleDataType BDataType, // BDataType XDataType, // BScaleDataType CDataType, // CDataType AccDataType, // GemmAccDataType CShuffleDataType, // CShuffleDataType AElementOp, // AElementwiseOperation BElementOp, // BElementwiseOperation CElementOp, // CElementwiseOperation GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 256, // BlockSize: Number of threads per block 128, // MPerBlock 128, // NPerBlock KPerBlock, // KPerBlock 1, // AK1 number of elements to read at a time when transferring from global memory to LDS 1, // BK1 16, // MPerXDL 16, // NPerXDL 4, // MXdlPerWave 4, // NXdlPerWave S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 286dffc36c..46028b79f9 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -556,6 +556,64 @@ struct Tensor return ck::f4x2_pk_t{ck::type_convert( ck::float2_t{ck::type_convert(fn(dis_(g_))), ck::type_convert(fn(dis_(g_)))})}; + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return ck::type_convert( + ck::float32_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))}); + } + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return ck::type_convert( + ck::float16_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))}); + } else static_assert(false, "Unsupported packed size for T"); }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index 5370cfa975..c929956124 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base static constexpr index_t AMmaKStride = KPack; static constexpr index_t BMmaKStride = KPack; - //> store rows/cols into thread registers in chunks of 16 - //> e.g. 
[k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] - static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA); + // store rows/cols into thread registers in chunks of 16 for FP8 + // e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] + // or in chunks of 32 / APackedSize for FP6/FP4 + static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize; + + static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now"); static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index ed168195ec..ae9b75cb0d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -54,6 +54,8 @@ namespace device { * * Conditions for achieving computational load balancing on different hardware platforms can vary. * + * \tparam KPerBlock is the number of elements in K dimension that each block processes (multiply with packed_size_v to get the actual KPerBlock) + * * Serialized version of the algorithm: * \code * // E = A * B + C @@ -117,7 +119,7 @@ template , f6x16_pk_t> || + is_same_v, bf6x16_pk_t> || + is_same_v, f6x32_pk_t> || + is_same_v, bf6x32_pk_t>)&&GemmSpec != + GemmSpecialization::Default), + "Packed F6 types do not support padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 1dd766eca0..64d7f92750 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -889,7 +889,6 @@ struct mfma_type const ScaleB& scale_b, FloatC& reg_c) const { - intrin_mfma_scale_f32_32x32x64f8f6f4::Run( a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); } @@ -1224,6 +1223,27 @@ struct MfmaSelector return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { @@ -1405,8 +1425,7 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(KPack * 2 % mfma_instr.k_per_blk == 0, - "KPack should be a multiple of k_per_blk"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); } // XDL output supporting C = A * B diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 56da5c1dc8..efb877b3f2 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1037,6 +1037,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> #endif } + template + __device__ static void Run(const f6x16x2_t& reg_a, + const int32_t scale_a, + const f6x16x2_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + using arg_type = int32x8_t; + arg_type 
arg_a{ + static_cast(reg_a.template AsType()[Number<0>{}][0]), + static_cast(reg_a.template AsType()[Number<0>{}][1]), + static_cast(reg_a.template AsType()[Number<0>{}][2]), + static_cast(reg_a.template AsType()[Number<1>{}][0]), + static_cast(reg_a.template AsType()[Number<1>{}][1]), + static_cast(reg_a.template AsType()[Number<1>{}][2]), + 0, + 0}; + arg_type arg_b{ + static_cast(reg_b.template AsType()[Number<0>{}][0]), + static_cast(reg_b.template AsType()[Number<0>{}][1]), + static_cast(reg_b.template AsType()[Number<0>{}][2]), + static_cast(reg_b.template AsType()[Number<1>{}][0]), + static_cast(reg_b.template AsType()[Number<1>{}][1]), + static_cast(reg_b.template AsType()[Number<1>{}][2]), + 0, + 0}; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_a, + arg_b, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf6x32_t& reg_a, const int32_t scale_a, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 51da18cd2b..15b8841c39 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -67,27 +67,42 @@ struct f6_pk_t { using element_type = uint32_t; // element storage fundamental type - static constexpr index_t packed_size = pk_size; - static constexpr index_t num_bits_elem = 6; - static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT; + static constexpr index_t packed_size = pk_size; // 16 or 32 for now + static constexpr index_t num_bits_elem = 6; // specialized for 6-bit data + // XXX: CHAR_BIT is not defined in HIPRTC, so we must use 8 + static constexpr index_t num_bits_vec_elem = + sizeof(element_type) * 8; // 32-bit uint for storage static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, "Packed elements must fit exactly into the element storage."); - static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + static constexpr index_t vector_size = + (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units - using storage_type = StaticallyIndexedArray_v2; - storage_type data; // packed data + using storage_type = element_type __attribute__((ext_vector_type(vector_size))); + storage_type data_{storage_type(0)}; // packed data using type = f6_pk_t; - __host__ __device__ constexpr f6_pk_t() : data{} {} - __host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {} + __host__ __device__ constexpr f6_pk_t() {} + __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} + { + // TODO: consider removing initialization similar to vector_type + } + + // Initialize from a vector type with the same size as packed_size template ::vector_size == packed_size>> - __host__ __device__ f6_pk_t(const T& v) : data{} + __host__ __device__ f6_pk_t(const T& v) { static_for<0, packed_size, 1>{}( [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); } + // Broadcast single initialization value to all packed elements + __host__ __device__ f6_pk_t(const int8_t v) + : f6_pk_t(static_cast(v)) + { + // TODO: consider removing initialization similar to vector_type + } + template __host__ __device__ void pack(const T x, const index_t i) { @@ -99,18 +114,18 @@ struct f6_pk_t const 
int arr_index = bit_pos / num_bits_vec_elem; const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = data.data_[arr_index]; + uint32_t old_value = data_[arr_index]; // insert bits into the current 32-bit block old_value |= (bits << bit_offset); - data.data_[arr_index] = old_value; + data_[arr_index] = old_value; // if it crosses into the next block, shift the remainder if(overhang > 0 && (arr_index + 1) < vector_size) { - uint32_t next_value = data.data_[arr_index + 1]; + uint32_t next_value = data_[arr_index + 1]; next_value |= (bits >> (num_bits_elem - overhang)); - data.data_[arr_index + 1] = next_value; + data_[arr_index + 1] = next_value; } } @@ -121,17 +136,33 @@ struct f6_pk_t const int bit_offset = bit_pos % num_bits_vec_elem; const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t bits = pk.data.data_[arr_idx] >> bit_offset; + uint32_t bits = pk.data_[arr_idx] >> bit_offset; if(overhang > 0 && (arr_idx + 1) < vector_size) { - bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); } return static_cast(bits & 0x3F); } __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); } + + // Compare operator + __host__ __device__ friend bool operator==(const f6_pk_t& lhs, const f6_pk_t& rhs) + { +#pragma unroll + for(index_t i = 0; i < vector_size; ++i) + { + if(lhs.data_[i] != rhs.data_[i]) + return false; + } + return true; + } + + __host__ __device__ friend bool operator!=(const f6_pk_t& lhs, const f6_pk_t& rhs) + { + return !(lhs == rhs); + } }; using f6x16_pk_t = f6_pk_t; @@ -296,6 +327,34 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = f6x32_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf6x32_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f6x16_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf6x16_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 0891a7ccf4..effe445883 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1438,14 +1438,16 @@ struct non_native_vector_base< // implementation for f6x16 and f6x32 template -struct non_native_vector_base> +struct non_native_vector_base< + T, + N, + ck::enable_if_t> { using data_t = typename nnvb_data_t_selector::type; // select data_t based on declared base type using element_t = typename T::element_type; // select element_t based on declared element type static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - static constexpr size_t size_factor = - sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); using type = non_native_vector_base; @@ -1457,29 +1459,29 @@ struct non_native_vector_base dNx1; } data_; - __host__ __device__ constexpr non_native_vector_base(data_t a) - : 
data_{data_v(a.At(Number<0>{}))} + // Broadcast single value to vector + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{} { + // TODO: consider removing initialization similar to vector_type + + ck::static_for<0, N, 1>{}([&](auto i) { + data_.dxN(i) = a; // broadcast value to all elements + }); } + __host__ __device__ constexpr non_native_vector_base(T f) : non_native_vector_base(bit_cast(f)) { } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + __host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {} + __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } + __host__ __device__ constexpr operator T() const { if constexpr(N == 1) @@ -1488,7 +1490,31 @@ struct non_native_vector_base + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dNx1; + } + else if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else + { + return err; } } }; @@ -1504,8 +1530,10 @@ struct scalar_type -struct scalar_type< - non_native_vector_base>> +struct scalar_type>> { using type = typename non_native_vector_base::element_t; static constexpr index_t vector_size = N * non_native_vector_base::size_factor; @@ -2221,8 +2249,9 @@ using f4x32_t = typename vector_type::type; using f4x64_t = typename vector_type::type; // f6 -using f6x16_t = typename vector_type::type; -using f6x32_t = typename vector_type::type; +using f6x16_t = typename vector_type::type; +using f6x16x2_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; // bf6 using bf6x16_t = typename vector_type::type; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 2debd09c2d..ed42b22daf 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -34,6 +34,10 @@ struct DynamicBuffer ElementSpaceSize element_space_size_; T invalid_element_value_ = T{0}; + // XXX: PackedSize semantics for pk_i4_t is different from the other packed types. + // Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while + // objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t must + // be divided by 2 to correctly represent the number of addressable elements. 
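// As a rough sketch of the two conventions described above (hypothetical
// buffers, not part of this header): to hold 256 scalar values,
//
//   pk_i4_t    buf_i4[128];  // each object packs 2 int4 values; the incoming
//                            // element_space_size_ counts scalars (256), so it
//                            // must be divided by PackedSize == 2 below.
//   f6x16_pk_t buf_f6[16];   // each object packs 16 fp6 values but is counted
//                            // as one element; the incoming element_space_size_
//                            // already counts objects (16), so no correction
//                            // is needed.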
static constexpr index_t PackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index f3e2bd3dd9..90a018fe3a 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -501,8 +501,8 @@ inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t sc float float_array[32]; } out{}; - out.float_vector = - __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert(scale)); + out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( + in.f6_vector.template AsType()[Number<0>{}], type_convert(scale)); return out.float_array[0]; #else return utils::to_float(scale, x); @@ -522,7 +522,8 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m f6x32_t x) { #if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert(scale)); + return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( + x.template AsType()[Number<0>{}], type_convert(scale)); #else union { @@ -567,8 +568,8 @@ inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t s float float_array[32]; } out{}; - out.float_vector = - __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert(scale)); + out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( + in.bf6_vector.template AsType()[Number<0>{}], type_convert(scale)); return out.float_array[0]; #else return utils::to_float(scale, x); @@ -588,7 +589,8 @@ inline __host__ __device__ float32_t scaled_type_convert(e8 bf6x32_t x) { #if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert(scale)); + return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( + x.template AsType()[Number<0>{}], type_convert(scale)); #else union { diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 69a953b575..23ab1bebb5 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1734,7 +1734,7 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f) f6_t f6_array[32]; } out{}; - out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale); + out.f6_vector = f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale)}; return out.f6_array[0]; #else @@ -1757,7 +1757,7 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 #if defined(__gfx950__) float16_t* in1 = reinterpret_cast(&x); float16_t* in2 = reinterpret_cast(&x + 16); - return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale); + return f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale)}; #else union { @@ -1765,17 +1765,15 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 float float_array[32]; } in{x}; - union - { - f6x32_t f6_vector; - f6_t f6_array[32]; - } out{}; + using array_type = uint8_t __attribute__((ext_vector_type(32))); + array_type uint8_array; + // collect the 6-bit values into an array ck::static_for<0, 32, 1>{}([&](auto i) { - out.f6_array[i] = utils::sat_convert_to_type(in.float_array[i] / scale); + uint8_array[static_cast(i)] = + utils::sat_convert_to_type(in.float_array[i] / scale); }); - - return out.f6_vector; + return f6x32_t{f6x32_pk_t{uint8_array}}; #endif } @@ -1807,7 +1805,8 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) f6_t f6_array[32]; } out{}; - out.f6_vector = 
__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale); + out.f6_vector = + f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale)}; return out.f6_array[0]; #else @@ -1837,7 +1836,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f // use HW clock for stochastic input multiply by incremented thread id uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * (get_thread_global_1d_id() + 1)); - return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); + return f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale)}; #else constexpr int seed = 1254739; union @@ -1852,6 +1851,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif + union { float32_t float_vector; @@ -1914,6 +1914,43 @@ inline __host__ __device__ f6x32_t type_convert(float32_t x) #endif } +template <> +inline __host__ __device__ f6x32_pk_t type_convert(float32_t x) +{ + return static_cast(type_convert(x)); +} + +template <> +inline __host__ __device__ f6x16_t type_convert(float16_t x) +{ + + union + { + float16_t v16x2[2]; + float32_t v32; + } in{{x, x}}; + + union + { + f6x32_t v32; + f6x16_t v16x2[2]; + } out{}; + +#if CK_USE_SR_F6_CONVERSION + out.v32 = f6_convert_sr(in.v32); +#else + out.v32 = f6_convert_rne(in.v32); +#endif + + return out.v16x2[0]; +} + +template <> +inline __host__ __device__ f6x16_pk_t type_convert(float16_t x) +{ + return static_cast(type_convert(x)); +} + /** * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to * float. @@ -1929,9 +1966,9 @@ inline __host__ __device__ float type_convert(f6_t x) #if defined(__gfx950__) union { - f6x32_t f6_vector; f6_t f6_array[32]; - } in{x}; + f6x32_t f6_vector; + } in{{x}}; union { @@ -1940,7 +1977,8 @@ inline __host__ __device__ float type_convert(f6_t x) } out{}; out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( - in.f6_vector, type_convert(NumericLimits::Binary_1())); + in.f6_vector.template AsType()[Number<0>{}], + type_convert(NumericLimits::Binary_1())); return out.float_array[0]; #else return utils::to_float(NumericLimits::Binary_1(), x); @@ -1948,8 +1986,8 @@ inline __host__ __device__ float type_convert(f6_t x) } /** - * @brief Specializes the type conversion template for converting the vector of 32 6-bit float types - * (f6x32_t) to vector of 32 floats. + * @brief Specializes the type conversion template for converting the vector of 32 6-bit float + * types (f6x32_t) to vector of 32 floats. * * Interprets an f6_t values as floats using the default scale factor of 1. 
* @@ -1961,7 +1999,8 @@ inline __host__ __device__ float32_t type_convert(f6x32_t x) { #if defined(__gfx950__) return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( - x, type_convert(NumericLimits::Binary_1())); + x.template AsType()[Number<0>{}], + type_convert(NumericLimits::Binary_1())); #else union { @@ -1984,6 +2023,31 @@ inline __host__ __device__ float32_t type_convert(f6x32_t x) #endif } +template <> +inline __host__ __device__ float16_t type_convert(f6x16_t x) +{ + union + { + f6x16_t v16x2[2]; + f6x32_t v32; + } in{{x, x}}; + + union + { + float16_t v16x2[2]; + float32_t v32; + } out{}; + + out.v32 = type_convert(in.v32); + return out.v16x2[0]; +} + +template <> +inline __host__ __device__ float16_t type_convert(f6x16_pk_t x) +{ + return type_convert(static_cast(x)); +} + /** * @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even. * @@ -2006,7 +2070,7 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f) bf6_t bf6_array[32]; } out{}; - out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale); + out.bf6_vector = bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale)}; return out.bf6_array[0]; #else @@ -2030,7 +2094,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 #if defined(__gfx950__) float16_t* in1 = reinterpret_cast(&x); float16_t* in2 = reinterpret_cast(&x + 16); - return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale); + return bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale)}; #else union { @@ -2081,7 +2145,8 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) bf6_t bf6_array[32]; } out{}; - out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale); + out.bf6_vector = + bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale)}; return out.bf6_array[0]; #else @@ -2113,7 +2178,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1. // use HW clock for stochastic input multiply by incremented thread id uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * (get_thread_global_1d_id() + 1)); - return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); + return bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale)}; #else constexpr int seed = 1254739; union @@ -2186,6 +2251,12 @@ inline __host__ __device__ bf6x32_t type_convert(float32_t #endif } +template <> +inline __host__ __device__ bf6x32_pk_t type_convert(float32_t x) +{ + return static_cast(type_convert(x)); +} + /** * @brief Specializes the type conversion template for converting a bf6_t value to float. 
* @@ -2201,9 +2272,9 @@ inline __host__ __device__ float type_convert(bf6_t x) #if defined(__gfx950__) union { - bf6x32_t bf6_vector; bf6_t bf6_array[32]; - } in{x}; + bf6x32_t bf6_vector; + } in{{x}}; union { @@ -2212,7 +2283,8 @@ inline __host__ __device__ float type_convert(bf6_t x) } out{}; out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( - in.bf6_vector, type_convert(NumericLimits::Binary_1())); + in.bf6_vector.template AsType()[Number<0>{}], + type_convert(NumericLimits::Binary_1())); return out.float_array[0]; #else return utils::to_float(NumericLimits::Binary_1(), x); @@ -2234,7 +2306,8 @@ inline __host__ __device__ float32_t type_convert(bf6x32_t { #if defined(__gfx950__) return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( - x, type_convert(NumericLimits::Binary_1())); + x.template AsType()[Number<0>{}], + type_convert(NumericLimits::Binary_1())); #else union { diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 8f6e9a0d15..7e23998f8c 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -53,6 +53,7 @@ if(GPU_TARGETS MATCHES "gfx950") add_gtest_executable(test_fp6 test_fp6.cpp) if(result EQUAL 0) + target_compile_options(test_fp6 PRIVATE -mavx512f) target_link_libraries(test_fp6 PRIVATE utility) endif() add_dependencies(test_mx_data_types test_fp6) diff --git a/test/data_type/test_bf6.cpp b/test/data_type/test_bf6.cpp index 9dbb77454c..25c01076e9 100644 --- a/test/data_type/test_bf6.cpp +++ b/test/data_type/test_bf6.cpp @@ -228,8 +228,8 @@ TEST(BF6, ScaledConvertFP32Stochastic) TEST(BF6, TestSize) { ASSERT_EQ(1, sizeof(bf6_t)); - ASSERT_EQ(12, sizeof(bf6x16_pk_t)); - ASSERT_EQ(24, sizeof(bf6x32_pk_t)); + ASSERT_EQ(16, sizeof(bf6x16_pk_t)); + ASSERT_EQ(32, sizeof(bf6x32_pk_t)); ASSERT_EQ(16, sizeof(vector_type)); ASSERT_EQ(32, sizeof(vector_type)); ASSERT_EQ(32, sizeof(vector_type)); @@ -238,8 +238,8 @@ TEST(BF6, TestSize) TEST(BF6, TestAlignment) { ASSERT_EQ(1, alignof(bf6_t)); - ASSERT_EQ(4, alignof(bf6x16_pk_t)); - ASSERT_EQ(4, alignof(bf6x32_pk_t)); + ASSERT_EQ(16, alignof(bf6x16_pk_t)); + ASSERT_EQ(32, alignof(bf6x32_pk_t)); ASSERT_EQ(16, alignof(vector_type)); ASSERT_EQ(32, alignof(vector_type)); ASSERT_EQ(32, alignof(vector_type)); diff --git a/test/data_type/test_fp6.cpp b/test/data_type/test_fp6.cpp index 6d4aec1d9a..14afe3e2e4 100644 --- a/test/data_type/test_fp6.cpp +++ b/test/data_type/test_fp6.cpp @@ -6,6 +6,7 @@ #include "ck/utility/type_convert.hpp" #include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" +#include "ck/library/utility/device_memory.hpp" using ck::e8m0_bexp_t; using ck::f6_convert_rne; @@ -227,8 +228,8 @@ TEST(FP6, ScaledConvertFP32Stochastic) TEST(FP6, TestSize) { ASSERT_EQ(1, sizeof(f6_t)); - ASSERT_EQ(12, sizeof(f6x16_pk_t)); - ASSERT_EQ(24, sizeof(f6x32_pk_t)); + ASSERT_EQ(16, sizeof(f6x16_pk_t)); + ASSERT_EQ(32, sizeof(f6x32_pk_t)); ASSERT_EQ(16, sizeof(vector_type)); ASSERT_EQ(32, sizeof(vector_type)); ASSERT_EQ(32, sizeof(vector_type)); @@ -237,8 +238,8 @@ TEST(FP6, TestSize) TEST(FP6, TestAlignment) { ASSERT_EQ(1, alignof(f6_t)); - ASSERT_EQ(4, alignof(f6x16_pk_t)); - ASSERT_EQ(4, alignof(f6x32_pk_t)); + ASSERT_EQ(16, alignof(f6x16_pk_t)); + ASSERT_EQ(32, alignof(f6x32_pk_t)); ASSERT_EQ(16, alignof(vector_type)); ASSERT_EQ(32, alignof(vector_type)); ASSERT_EQ(32, alignof(vector_type)); @@ -292,6 +293,60 @@ TEST(FP6, TestAsType16x1) }); } +__global__ void test_f6_convert_rne(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed 
== nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + ck::float32_t float32_in(1.0f); + ck::float32_t float32_out{}; + + auto f6x32_vec = f6_convert_rne(float32_in); + float32_out = type_convert<ck::float32_t>(f6x32_vec); + + ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32_out[static_cast<int>(ii)]; }); + i = N; +} + +TEST(MXFP6, DeviceF6ConvertRNE) +{ + constexpr int N = 32; + std::vector<float> out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_f6_convert_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()), + static_cast<uint64_t*>(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + EXPECT_EQ(N, completed); + ck::static_for<0, N, 1>{}( + [&](auto ii) { EXPECT_EQ(out[static_cast<int>(ii)], 1.0f) << "ii: " << ii << std::endl; }); + + auto f6x32_vec_tc = ck::type_convert<f6x32_pk_t>(ck::float32_t(1.0f)); + auto f6x32_vec_cnstr = f6x32_pk_t(0x08); + + EXPECT_EQ(f6x32_vec_tc, f6x32_vec_cnstr); +} + // test vector of 2 f6x16_pk_t, contains 32 f6_t TEST(FP6, TestAsType16x2) {