mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '054f85ab7c0fa07a90968e834899ec415af8b713' into develop
This commit is contained in:
@@ -14,7 +14,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM.
|
||||
* Added support for Multiple D GEMM.
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types.
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
|
||||
@@ -10,6 +10,9 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
|
||||
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp6)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
|
||||
|
||||
@@ -55,3 +58,7 @@ set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP6_MXGEMM_OPTIONS)
|
||||
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
|
||||
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
|
||||
|
||||
@@ -245,6 +245,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
if(K % ck::packed_size_v<ADataType> != 0 || K % ck::packed_size_v<BDataType> != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of packed size.");
|
||||
};
|
||||
|
||||
// Hardcode scale layouts as per pipeline assumptions
|
||||
// TODO: Allow user to specify scale layouts
|
||||
using AScaleLayout = Row;
|
||||
@@ -292,12 +297,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
auto a_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<ADataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 32)
|
||||
return ck::type_convert<ADataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<ADataType> == 16)
|
||||
return ck::type_convert<ADataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<ADataType>(x);
|
||||
};
|
||||
auto b_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<BDataType>(ck::float2_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 32)
|
||||
return ck::type_convert<BDataType>(ck::float32_t(x));
|
||||
else if constexpr(ck::packed_size_v<BDataType> == 16)
|
||||
return ck::type_convert<BDataType>(ck::float16_t(x));
|
||||
else
|
||||
return ck::type_convert<BDataType>(x);
|
||||
};
|
||||
@@ -307,30 +320,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: // Initializations for development and debugging
|
||||
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillConstant<ADataType>{a_data_element(0.5f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
|
||||
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
std::cout << "Init A = {0.5}" << std::endl;
|
||||
std::cout << "Init A scale = {2.0}" << std::endl;
|
||||
std::cout << "Init B = {0.5}" << std::endl;
|
||||
std::cout << "Init B scale = {1.0}" << std::endl;
|
||||
std::cout << "Init B = {2.0}" << std::endl;
|
||||
std::cout << "Init B scale = {0.5}" << std::endl;
|
||||
std::cout << "Expect C = {K}" << std::endl;
|
||||
}
|
||||
break;
|
||||
|
||||
case 1:
|
||||
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
a_m_k.GenerateTensorDistr(
|
||||
int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5]
|
||||
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
a_m_k.GenerateTensorDistr(
|
||||
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
|
||||
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
|
||||
99
example/67_gemm_microscaling/gemm_mx_fp6.cpp
Normal file
99
example/67_gemm_microscaling/gemm_mx_fp6.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f6x16_pk_t;
|
||||
using BDataType = ck::f6x16_pk_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v<ADataType>; // K dimension size per block
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
// Device-level MX GEMM instance for packed FP6 operands.
// Template arguments are listed in the order the device op declares them and
// are grouped below by purpose; the trailing comment on each row names them.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
    // Layouts
    ALayout, BLayout, CLayout, // ALayout, BLayout, CLayout
    // Data types
    ADataType, XDataType, // ADataType, AScaleDataType
    BDataType, XDataType, // BDataType, BScaleDataType
    CDataType,            // CDataType
    AccDataType,          // GemmAccDataType
    CShuffleDataType,     // CShuffleDataType
    // Elementwise operations
    AElementOp, BElementOp, CElementOp, // A/B/C ElementwiseOperation
    // Specialization and scaling
    GemmSpec,       // GemmSpec
    ScaleBlockSize, // ScaleBlockSize: scaling block size
    // Block / tile shape
    256,       // BlockSize: number of threads per block
    128, 128,  // MPerBlock, NPerBlock
    KPerBlock, // KPerBlock
    1, 1,      // AK1, BK1: elements read at a time when moving global memory -> LDS
    16, 16,    // MPerXDL, NPerXDL
    4, 4,      // MXdlPerWave, NXdlPerWave
    // A block transfer
    S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
    S<1, 0, 2>,   // ABlockTransferThreadClusterArrangeOrder
    S<1, 0, 2>,   // ABlockTransferSrcAccessOrder
    2,            // ABlockTransferSrcVectorDim
    1,            // ABlockTransferSrcScalarPerVector
    16,           // ABlockTransferDstScalarPerVector_AK1
    true,         // ABlockLdsExtraM
    // B block transfer
    S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
    S<1, 0, 2>,   // BBlockTransferThreadClusterArrangeOrder
    S<1, 0, 2>,   // BBlockTransferSrcAccessOrder
    2,            // BBlockTransferSrcVectorDim
    1,            // BBlockTransferSrcScalarPerVector
    16,           // BBlockTransferDstScalarPerVector_BK1
    true,         // BBlockLdsExtraN
    // C shuffle
    2, 2,           // CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle
    S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
    // Pipeline
    BlkGemmPSched, // BlkGemmPipeSched
    BlkGemmPVer,   // BlkGemmPipelineVer
    // Compute types
    ADataType, // ComputeTypeA
    BDataType  // ComputeTypeB
    >;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -556,6 +556,64 @@ struct Tensor
|
||||
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
|
||||
ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))})};
|
||||
else if constexpr(ck::is_same_v<T, ck::f6x32_pk_t> ||
|
||||
ck::is_same_v<T, ck::bf6x32_pk_t>)
|
||||
{
|
||||
return ck::type_convert<T>(
|
||||
ck::float32_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))});
|
||||
}
|
||||
else if constexpr(ck::is_same_v<T, ck::f6x16_pk_t> ||
|
||||
ck::is_same_v<T, ck::bf6x16_pk_t>)
|
||||
{
|
||||
return ck::type_convert<T>(
|
||||
ck::float16_t{ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_))),
|
||||
ck::type_convert<float>(fn(dis_(g_)))});
|
||||
}
|
||||
else
|
||||
static_assert(false, "Unsupported packed size for T");
|
||||
};
|
||||
|
||||
@@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t AMmaKStride = KPack;
|
||||
static constexpr index_t BMmaKStride = KPack;
|
||||
|
||||
//> store rows/cols into thread registers in chunks of 16
|
||||
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
|
||||
// store rows/cols into thread registers in chunks of 16 for FP8
|
||||
// e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
// or in chunks of 32 / APackedSize for FP6/FP4
|
||||
static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
|
||||
|
||||
static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now");
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
|
||||
@@ -54,6 +54,8 @@ namespace device {
|
||||
*
|
||||
* Conditions for achieving computational load balancing on different hardware platforms can vary.
|
||||
*
|
||||
* \tparam KPerBlock The number of elements along the K dimension that each block processes (multiply by packed_size_v to get the actual KPerBlock).
|
||||
*
|
||||
* Serialized version of the algorithm:
|
||||
* \code
|
||||
* // E = A * B + C
|
||||
@@ -117,7 +119,7 @@ template <typename ALayout,
|
||||
index_t BlockSize, // Thread block size
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t KPerBlock, // multiply with packed_size_v to get the actual KPerBlock
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
|
||||
@@ -419,6 +419,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MPadding)),
|
||||
"f4x2_pk_t does not support K padding");
|
||||
static_assert(!((is_same_v<remove_cvref_t<ADataType>, f6x16_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, bf6x16_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f6x32_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, bf6x32_pk_t>)&&GemmSpec !=
|
||||
GemmSpecialization::Default),
|
||||
"Packed F6 types do not support padding");
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
|
||||
@@ -889,7 +889,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
|
||||
}
|
||||
@@ -1224,6 +1223,27 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
|
||||
{
|
||||
@@ -1405,8 +1425,7 @@ struct XdlopsGemm
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
|
||||
"KPack should be a multiple of k_per_blk");
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
|
||||
@@ -1037,6 +1037,54 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f6x16x2_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f6x16x2_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
using arg_type = int32x8_t;
|
||||
arg_type arg_a{
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
|
||||
static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
|
||||
0,
|
||||
0};
|
||||
arg_type arg_b{
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
|
||||
static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
|
||||
0,
|
||||
0};
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_a,
|
||||
arg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
2, // blgp
|
||||
OpselA, // OPSEL
|
||||
scale_a,
|
||||
OpselB, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf6x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
|
||||
@@ -67,27 +67,42 @@ struct f6_pk_t
|
||||
{
|
||||
using element_type = uint32_t; // element storage fundamental type
|
||||
|
||||
static constexpr index_t packed_size = pk_size;
|
||||
static constexpr index_t num_bits_elem = 6;
|
||||
static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT;
|
||||
static constexpr index_t packed_size = pk_size; // 16 or 32 for now
|
||||
static constexpr index_t num_bits_elem = 6; // specialized for 6-bit data
|
||||
// XXX: CHAR_BIT is not defined in HIPRTC, so we must use 8
|
||||
static constexpr index_t num_bits_vec_elem =
|
||||
sizeof(element_type) * 8; // 32-bit uint for storage
|
||||
static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
|
||||
"Packed elements must fit exactly into the element storage.");
|
||||
static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;
|
||||
static constexpr index_t vector_size =
|
||||
(packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units
|
||||
|
||||
using storage_type = StaticallyIndexedArray_v2<element_type, vector_size>;
|
||||
storage_type data; // packed data
|
||||
using storage_type = element_type __attribute__((ext_vector_type(vector_size)));
|
||||
storage_type data_{storage_type(0)}; // packed data
|
||||
|
||||
using type = f6_pk_t<BitType, packed_size>;
|
||||
|
||||
__host__ __device__ constexpr f6_pk_t() : data{} {}
|
||||
__host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {}
|
||||
__host__ __device__ constexpr f6_pk_t() {}
|
||||
__host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init}
|
||||
{
|
||||
// TODO: consider removing initialization similar to vector_type<T, 256>
|
||||
}
|
||||
|
||||
// Initialize from a vector type with the same size as packed_size
|
||||
template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>>
|
||||
__host__ __device__ f6_pk_t(const T& v) : data{}
|
||||
__host__ __device__ f6_pk_t(const T& v)
|
||||
{
|
||||
static_for<0, packed_size, 1>{}(
|
||||
[&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); });
|
||||
}
|
||||
|
||||
// Broadcast single initialization value to all packed elements
|
||||
__host__ __device__ f6_pk_t(const int8_t v)
|
||||
: f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v))
|
||||
{
|
||||
// TODO: consider removing initialization similar to vector_type<T, 256>
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void pack(const T x, const index_t i)
|
||||
{
|
||||
@@ -99,18 +114,18 @@ struct f6_pk_t
|
||||
const int arr_index = bit_pos / num_bits_vec_elem;
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
uint32_t old_value = data.data_[arr_index];
|
||||
uint32_t old_value = data_[arr_index];
|
||||
|
||||
// insert bits into the current 32-bit block
|
||||
old_value |= (bits << bit_offset);
|
||||
data.data_[arr_index] = old_value;
|
||||
data_[arr_index] = old_value;
|
||||
|
||||
// if it crosses into the next block, shift the remainder
|
||||
if(overhang > 0 && (arr_index + 1) < vector_size)
|
||||
{
|
||||
uint32_t next_value = data.data_[arr_index + 1];
|
||||
uint32_t next_value = data_[arr_index + 1];
|
||||
next_value |= (bits >> (num_bits_elem - overhang));
|
||||
data.data_[arr_index + 1] = next_value;
|
||||
data_[arr_index + 1] = next_value;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,17 +136,33 @@ struct f6_pk_t
|
||||
const int bit_offset = bit_pos % num_bits_vec_elem;
|
||||
const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
|
||||
|
||||
uint32_t bits = pk.data.data_[arr_idx] >> bit_offset;
|
||||
uint32_t bits = pk.data_[arr_idx] >> bit_offset;
|
||||
if(overhang > 0 && (arr_idx + 1) < vector_size)
|
||||
{
|
||||
bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1))
|
||||
<< (num_bits_elem - overhang);
|
||||
bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
|
||||
}
|
||||
|
||||
return static_cast<BitType>(bits & 0x3F);
|
||||
}
|
||||
|
||||
__host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); }
|
||||
|
||||
// Compare operator
|
||||
__host__ __device__ friend bool operator==(const f6_pk_t& lhs, const f6_pk_t& rhs)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < vector_size; ++i)
|
||||
{
|
||||
if(lhs.data_[i] != rhs.data_[i])
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Inequality, defined as the negation of operator==.
__host__ __device__ friend bool operator!=(const f6_pk_t& lhs, const f6_pk_t& rhs)
{
    const bool same = (lhs == rhs);
    return !same;
}
|
||||
};
|
||||
|
||||
using f6x16_pk_t = f6_pk_t<f6_t, 16>;
|
||||
@@ -296,6 +327,34 @@ struct scalar_type<f4x2_pk_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<f6x32_pk_t>
|
||||
{
|
||||
using type = f6x32_pk_t::storage_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf6x32_pk_t>
|
||||
{
|
||||
using type = bf6x32_pk_t::storage_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<f6x16_pk_t>
|
||||
{
|
||||
using type = f6x16_pk_t::storage_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bf6x16_pk_t>
|
||||
{
|
||||
using type = bf6x16_pk_t::storage_type;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
|
||||
@@ -1438,14 +1438,16 @@ struct non_native_vector_base<
|
||||
|
||||
// implementation for f6x16 and f6x32
|
||||
template <typename T, index_t N>
|
||||
struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
|
||||
struct non_native_vector_base<
|
||||
T,
|
||||
N,
|
||||
ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 16 || sizeof(T) == 24 || sizeof(T) == 32>>
|
||||
{
|
||||
using data_t =
|
||||
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
|
||||
using element_t = typename T::element_type; // select element_t based on declared element type
|
||||
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
|
||||
static constexpr size_t size_factor =
|
||||
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
|
||||
static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t);
|
||||
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
|
||||
using type = non_native_vector_base<T, N>;
|
||||
|
||||
@@ -1457,29 +1459,29 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
StaticallyIndexedArray<data_v, 1> dNx1;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base(data_t a)
|
||||
: data_{data_v(a.At(Number<0>{}))}
|
||||
// Broadcast single value to vector
|
||||
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{}
|
||||
{
|
||||
// TODO: consider removing initialization similar to vector_type<T, 256>
|
||||
|
||||
ck::static_for<0, N, 1>{}([&](auto i) {
|
||||
data_.dxN(i) = a; // broadcast value to all elements
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base(T f)
|
||||
: non_native_vector_base(bit_cast<data_t>(f))
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
|
||||
|
||||
__host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {}
|
||||
|
||||
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
|
||||
__host__ __device__ constexpr operator data_t() const
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return data_.dxN[Number<0>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
return data_.dxN; // XXX this should cause an error
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr operator T() const
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
@@ -1488,7 +1490,31 @@ struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T)
|
||||
}
|
||||
else
|
||||
{
|
||||
return data_.dTxN; // XXX this should cause an error
|
||||
return err; // XXX this should cause an error
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same_v<X, data_t> || is_same_v<X, data_v> || is_same_v<X, T>,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same_v<X, data_v>)
|
||||
{
|
||||
return data_.dNx1;
|
||||
}
|
||||
else if constexpr(is_same_v<X, data_t>)
|
||||
{
|
||||
return data_.dxN;
|
||||
}
|
||||
else if constexpr(is_same_v<X, T>)
|
||||
{
|
||||
return data_.dTxN;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -1504,8 +1530,10 @@ struct scalar_type<non_native_vector_base<
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct scalar_type<
|
||||
non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>>
|
||||
struct scalar_type<non_native_vector_base<
|
||||
T,
|
||||
N,
|
||||
ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 16 || sizeof(T) == 24 || sizeof(T) == 32>>>
|
||||
{
|
||||
using type = typename non_native_vector_base<T, N>::element_t;
|
||||
static constexpr index_t vector_size = N * non_native_vector_base<T, N>::size_factor;
|
||||
@@ -2221,8 +2249,9 @@ using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
|
||||
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
|
||||
|
||||
// f6
|
||||
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
|
||||
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
|
||||
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
|
||||
using f6x16x2_t = typename vector_type<f6x16_pk_t, 2>::type;
|
||||
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
|
||||
|
||||
// bf6
|
||||
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
|
||||
|
||||
@@ -34,6 +34,10 @@ struct DynamicBuffer
|
||||
ElementSpaceSize element_space_size_;
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
// XXX: PackedSize semantics for pk_i4_t is different from the other packed types.
|
||||
// Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while
|
||||
// objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t must
|
||||
// be divided by 2 to correctly represent the number of addressable elements.
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
|
||||
return 2;
|
||||
|
||||
@@ -501,8 +501,8 @@ inline __host__ __device__ float scaled_type_convert<float, f6_t>(e8m0_bexp_t sc
|
||||
float float_array[32];
|
||||
} out{};
|
||||
|
||||
out.float_vector =
|
||||
__builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
|
||||
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
|
||||
in.f6_vector.template AsType<f6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
|
||||
return out.float_array[0];
|
||||
#else
|
||||
return utils::to_float<f6_t>(scale, x);
|
||||
@@ -522,7 +522,8 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f6x32_t>(e8m
|
||||
f6x32_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
|
||||
x.template AsType<f6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
|
||||
#else
|
||||
union
|
||||
{
|
||||
@@ -567,8 +568,8 @@ inline __host__ __device__ float scaled_type_convert<float, bf6_t>(e8m0_bexp_t s
|
||||
float float_array[32];
|
||||
} out{};
|
||||
|
||||
out.float_vector =
|
||||
__builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert<float>(scale));
|
||||
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
|
||||
in.bf6_vector.template AsType<bf6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
|
||||
return out.float_array[0];
|
||||
#else
|
||||
return utils::to_float<bf6_t>(scale, x);
|
||||
@@ -588,7 +589,8 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, bf6x32_t>(e8
|
||||
bf6x32_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert<float>(scale));
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
|
||||
x.template AsType<bf6x32_t::data_t>()[Number<0>{}], type_convert<float>(scale));
|
||||
#else
|
||||
union
|
||||
{
|
||||
|
||||
@@ -1734,7 +1734,7 @@ inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
|
||||
f6_t f6_array[32];
|
||||
} out{};
|
||||
|
||||
out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale);
|
||||
out.f6_vector = f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale)};
|
||||
|
||||
return out.f6_array[0];
|
||||
#else
|
||||
@@ -1757,7 +1757,7 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
|
||||
#if defined(__gfx950__)
|
||||
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
|
||||
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
|
||||
return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale);
|
||||
return f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale)};
|
||||
#else
|
||||
union
|
||||
{
|
||||
@@ -1765,17 +1765,15 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
|
||||
float float_array[32];
|
||||
} in{x};
|
||||
|
||||
union
|
||||
{
|
||||
f6x32_t f6_vector;
|
||||
f6_t f6_array[32];
|
||||
} out{};
|
||||
using array_type = uint8_t __attribute__((ext_vector_type(32)));
|
||||
array_type uint8_array;
|
||||
|
||||
// collect the 6-bit values into an array
|
||||
ck::static_for<0, 32, 1>{}([&](auto i) {
|
||||
out.f6_array[i] = utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
|
||||
uint8_array[static_cast<index_t>(i)] =
|
||||
utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
|
||||
});
|
||||
|
||||
return out.f6_vector;
|
||||
return f6x32_t{f6x32_pk_t{uint8_array}};
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1807,7 +1805,8 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
f6_t f6_array[32];
|
||||
} out{};
|
||||
|
||||
out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale);
|
||||
out.f6_vector =
|
||||
f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale)};
|
||||
|
||||
return out.f6_array[0];
|
||||
#else
|
||||
@@ -1837,7 +1836,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
return f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale)};
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
@@ -1852,6 +1851,7 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
|
||||
union
|
||||
{
|
||||
float32_t float_vector;
|
||||
@@ -1914,6 +1914,43 @@ inline __host__ __device__ f6x32_t type_convert<f6x32_t, float32_t>(float32_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ f6x32_pk_t type_convert<f6x32_pk_t, float32_t>(float32_t x)
|
||||
{
|
||||
return static_cast<f6x32_pk_t>(type_convert<f6x32_t>(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ f6x16_t type_convert<f6x16_t, float16_t>(float16_t x)
|
||||
{
|
||||
|
||||
union
|
||||
{
|
||||
float16_t v16x2[2];
|
||||
float32_t v32;
|
||||
} in{{x, x}};
|
||||
|
||||
union
|
||||
{
|
||||
f6x32_t v32;
|
||||
f6x16_t v16x2[2];
|
||||
} out{};
|
||||
|
||||
#if CK_USE_SR_F6_CONVERSION
|
||||
out.v32 = f6_convert_sr(in.v32);
|
||||
#else
|
||||
out.v32 = f6_convert_rne(in.v32);
|
||||
#endif
|
||||
|
||||
return out.v16x2[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ f6x16_pk_t type_convert<f6x16_pk_t, float16_t>(float16_t x)
|
||||
{
|
||||
return static_cast<f6x16_pk_t>(type_convert<f6x16_t>(x));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to
|
||||
* float.
|
||||
@@ -1929,9 +1966,9 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
f6x32_t f6_vector;
|
||||
f6_t f6_array[32];
|
||||
} in{x};
|
||||
f6x32_t f6_vector;
|
||||
} in{{x}};
|
||||
|
||||
union
|
||||
{
|
||||
@@ -1940,7 +1977,8 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
|
||||
} out{};
|
||||
|
||||
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
|
||||
in.f6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
in.f6_vector.template AsType<f6x32_t::data_t>()[Number<0>{}],
|
||||
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
return out.float_array[0];
|
||||
#else
|
||||
return utils::to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
|
||||
@@ -1948,8 +1986,8 @@ inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float types
|
||||
* (f6x32_t) to vector of 32 floats.
|
||||
* @brief Specializes the type conversion template for converting the vector of 32 6-bit float
|
||||
* types (f6x32_t) to vector of 32 floats.
|
||||
*
|
||||
* Interprets f6_t values as floats using the default scale factor of 1.
|
||||
*
|
||||
@@ -1961,7 +1999,8 @@ inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
|
||||
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
x.template AsType<f6x32_t::data_t>()[Number<0>{}],
|
||||
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
#else
|
||||
union
|
||||
{
|
||||
@@ -1984,6 +2023,31 @@ inline __host__ __device__ float32_t type_convert<float32_t, f6x32_t>(f6x32_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float16_t type_convert<float16_t, f6x16_t>(f6x16_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
f6x16_t v16x2[2];
|
||||
f6x32_t v32;
|
||||
} in{{x, x}};
|
||||
|
||||
union
|
||||
{
|
||||
float16_t v16x2[2];
|
||||
float32_t v32;
|
||||
} out{};
|
||||
|
||||
out.v32 = type_convert<float32_t>(in.v32);
|
||||
return out.v16x2[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float16_t type_convert<float16_t, f6x16_pk_t>(f6x16_pk_t x)
|
||||
{
|
||||
return type_convert<float16_t>(static_cast<f6x16_t>(x));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even.
|
||||
*
|
||||
@@ -2006,7 +2070,7 @@ inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
|
||||
bf6_t bf6_array[32];
|
||||
} out{};
|
||||
|
||||
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale);
|
||||
out.bf6_vector = bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale)};
|
||||
|
||||
return out.bf6_array[0];
|
||||
#else
|
||||
@@ -2030,7 +2094,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
|
||||
#if defined(__gfx950__)
|
||||
float16_t* in1 = reinterpret_cast<float16_t*>(&x);
|
||||
float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
|
||||
return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale);
|
||||
return bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale)};
|
||||
#else
|
||||
union
|
||||
{
|
||||
@@ -2081,7 +2145,8 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
bf6_t bf6_array[32];
|
||||
} out{};
|
||||
|
||||
out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale);
|
||||
out.bf6_vector =
|
||||
bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale)};
|
||||
|
||||
return out.bf6_array[0];
|
||||
#else
|
||||
@@ -2113,7 +2178,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
|
||||
// use HW clock for stochastic input multiply by incremented thread id
|
||||
uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
|
||||
(get_thread_global_1d_id() + 1));
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
return bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale)};
|
||||
#else
|
||||
constexpr int seed = 1254739;
|
||||
union
|
||||
@@ -2186,6 +2251,12 @@ inline __host__ __device__ bf6x32_t type_convert<bf6x32_t, float32_t>(float32_t
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ bf6x32_pk_t type_convert<bf6x32_pk_t, float32_t>(float32_t x)
|
||||
{
|
||||
return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Specializes the type conversion template for converting a bf6_t value to float.
|
||||
*
|
||||
@@ -2201,9 +2272,9 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
bf6x32_t bf6_vector;
|
||||
bf6_t bf6_array[32];
|
||||
} in{x};
|
||||
bf6x32_t bf6_vector;
|
||||
} in{{x}};
|
||||
|
||||
union
|
||||
{
|
||||
@@ -2212,7 +2283,8 @@ inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
|
||||
} out{};
|
||||
|
||||
out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
|
||||
in.bf6_vector, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
in.bf6_vector.template AsType<bf6x32_t::data_t>()[Number<0>{}],
|
||||
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
return out.float_array[0];
|
||||
#else
|
||||
return utils::to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), x);
|
||||
@@ -2234,7 +2306,8 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
|
||||
x, type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
x.template AsType<bf6x32_t::data_t>()[Number<0>{}],
|
||||
type_convert<float>(NumericLimits<e8m0_bexp_t>::Binary_1()));
|
||||
#else
|
||||
union
|
||||
{
|
||||
|
||||
@@ -53,6 +53,7 @@ if(GPU_TARGETS MATCHES "gfx950")
|
||||
|
||||
add_gtest_executable(test_fp6 test_fp6.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_compile_options(test_fp6 PRIVATE -mavx512f)
|
||||
target_link_libraries(test_fp6 PRIVATE utility)
|
||||
endif()
|
||||
add_dependencies(test_mx_data_types test_fp6)
|
||||
|
||||
@@ -228,8 +228,8 @@ TEST(BF6, ScaledConvertFP32Stochastic)
|
||||
TEST(BF6, TestSize)
|
||||
{
|
||||
ASSERT_EQ(1, sizeof(bf6_t));
|
||||
ASSERT_EQ(12, sizeof(bf6x16_pk_t));
|
||||
ASSERT_EQ(24, sizeof(bf6x32_pk_t));
|
||||
ASSERT_EQ(16, sizeof(bf6x16_pk_t));
|
||||
ASSERT_EQ(32, sizeof(bf6x32_pk_t));
|
||||
ASSERT_EQ(16, sizeof(vector_type<bf6x16_pk_t, 1>));
|
||||
ASSERT_EQ(32, sizeof(vector_type<bf6x16_pk_t, 2>));
|
||||
ASSERT_EQ(32, sizeof(vector_type<bf6x32_pk_t, 1>));
|
||||
@@ -238,8 +238,8 @@ TEST(BF6, TestSize)
|
||||
TEST(BF6, TestAlignment)
|
||||
{
|
||||
ASSERT_EQ(1, alignof(bf6_t));
|
||||
ASSERT_EQ(4, alignof(bf6x16_pk_t));
|
||||
ASSERT_EQ(4, alignof(bf6x32_pk_t));
|
||||
ASSERT_EQ(16, alignof(bf6x16_pk_t));
|
||||
ASSERT_EQ(32, alignof(bf6x32_pk_t));
|
||||
ASSERT_EQ(16, alignof(vector_type<bf6x16_pk_t, 1>));
|
||||
ASSERT_EQ(32, alignof(vector_type<bf6x16_pk_t, 2>));
|
||||
ASSERT_EQ(32, alignof(vector_type<bf6x32_pk_t, 1>));
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::f6_convert_rne;
|
||||
@@ -227,8 +228,8 @@ TEST(FP6, ScaledConvertFP32Stochastic)
|
||||
TEST(FP6, TestSize)
|
||||
{
|
||||
ASSERT_EQ(1, sizeof(f6_t));
|
||||
ASSERT_EQ(12, sizeof(f6x16_pk_t));
|
||||
ASSERT_EQ(24, sizeof(f6x32_pk_t));
|
||||
ASSERT_EQ(16, sizeof(f6x16_pk_t));
|
||||
ASSERT_EQ(32, sizeof(f6x32_pk_t));
|
||||
ASSERT_EQ(16, sizeof(vector_type<f6x16_pk_t, 1>));
|
||||
ASSERT_EQ(32, sizeof(vector_type<f6x16_pk_t, 2>));
|
||||
ASSERT_EQ(32, sizeof(vector_type<f6x32_pk_t, 1>));
|
||||
@@ -237,8 +238,8 @@ TEST(FP6, TestSize)
|
||||
TEST(FP6, TestAlignment)
|
||||
{
|
||||
ASSERT_EQ(1, alignof(f6_t));
|
||||
ASSERT_EQ(4, alignof(f6x16_pk_t));
|
||||
ASSERT_EQ(4, alignof(f6x32_pk_t));
|
||||
ASSERT_EQ(16, alignof(f6x16_pk_t));
|
||||
ASSERT_EQ(32, alignof(f6x32_pk_t));
|
||||
ASSERT_EQ(16, alignof(vector_type<f6x16_pk_t, 1>));
|
||||
ASSERT_EQ(32, alignof(vector_type<f6x16_pk_t, 2>));
|
||||
ASSERT_EQ(32, alignof(vector_type<f6x32_pk_t, 1>));
|
||||
@@ -292,6 +293,60 @@ TEST(FP6, TestAsType16x1)
|
||||
});
|
||||
}
|
||||
|
||||
// Device kernel: converts a float32_t filled with 1.0f to f6x32 using f6_convert_rne, converts
// it back to float32_t, and stores the 32 resulting floats in p_test.
// *p_completed is reset to 0 on entry, incremented once per stored element, and finally set
// to N. The kernel returns early (without writing p_test) when either pointer is null.
__global__ void test_f6_convert_rne(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;

    // Without a progress counter there is nothing to report to.
    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& done = *p_completed;
    done           = 0;

    if(p_test == nullptr)
    {
        return;
    }

    ck::float32_t ones(1.0f);

    // Round trip: float32_t -> f6x32 -> float32_t.
    auto packed                = f6_convert_rne(ones);
    ck::float32_t round_trip   = type_convert<ck::float32_t>(packed);

    ck::static_for<0, N, 1>{}([&](auto lane) {
        p_test[done] = round_trip[static_cast<int>(lane)];
        ++done;
    });

    done = N;
}
|
||||
|
||||
// Launches test_f6_convert_rne on a single thread and checks that all N outputs were written
// and equal 1.0f, then compares the host-side float32_t -> f6x32_pk_t conversion of 1.0f
// against an f6x32_pk_t constructed from 0x08.
TEST(MXFP6, DeviceF6ConvertRNE)
{
    constexpr int N = 32;

    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both buffers with a sentinel pattern before the kernel runs.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    auto* p_out       = static_cast<float*>(device_out.GetDeviceBuffer());
    auto* p_completed = static_cast<uint64_t*>(device_completed.GetDeviceBuffer());

    test_f6_convert_rne<<<1, 1>>>(p_out, p_completed);

    // Copy the progress counter and the converted values back to the host.
    uint64_t completed = 0;
    std::vector<float> out(N, -1.0f);
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    EXPECT_EQ(N, completed);
    for(int idx = 0; idx < N; ++idx)
    {
        EXPECT_EQ(out[idx], 1.0f) << "ii: " << idx << std::endl;
    }

    // Host-side check of the packed conversion against a directly constructed value.
    auto converted_pk = ck::type_convert<f6x32_pk_t>(ck::float32_t(1.0f));
    auto expected_pk  = f6x32_pk_t(0x08);

    EXPECT_EQ(converted_pk, expected_pk);
}
|
||||
|
||||
// test vector of 2 f6x16_pk_t, contains 32 f6_t
|
||||
TEST(FP6, TestAsType16x2)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user