Merge commit '8a0d659f92897e1ae99e4dc0ea4842a2c78170ab' into develop

This commit is contained in:
github-actions[bot]
2025-05-06 15:07:46 +00:00
parent bf90418b06
commit b96328e63f
8 changed files with 610 additions and 79 deletions

View File

@@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t> ||
std::is_same_v<RangeType, ck::f4x2_pk_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
@@ -359,7 +360,8 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
@@ -514,7 +516,8 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
@@ -527,7 +530,8 @@ struct Tensor
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -540,7 +544,8 @@ struct Tensor
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -552,7 +557,8 @@ struct Tensor
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
@@ -564,7 +570,8 @@ struct Tensor
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,6 +81,18 @@ struct GeneratorTensor_1<ck::f4_t>
}
};
template <>
struct GeneratorTensor_1<ck::f4x2_pk_t>
{
float value = 1.0;
template <typename... Is>
ck::f4x2_pk_t operator()(Is...)
{
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{value, value})};
}
};
template <>
struct GeneratorTensor_1<int8_t>
{
@@ -209,6 +221,21 @@ struct GeneratorTensor_2<ck::f4_t>
}
};
template <>
struct GeneratorTensor_2<ck::f4x2_pk_t>
{
int min_value = 0;
int max_value = 1;
template <typename... Is>
ck::f4x2_pk_t operator()(Is...)
{
float tmp0 = (std::rand() % (max_value - min_value)) + min_value;
float tmp1 = (std::rand() % (max_value - min_value)) + min_value;
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{tmp0, tmp1})};
}
};
template <typename T>
struct GeneratorTensor_3
{
@@ -296,6 +323,25 @@ struct GeneratorTensor_3<ck::f4_t>
}
};
template <>
struct GeneratorTensor_3<ck::f4x2_pk_t>
{
float min_value = 0;
float max_value = 1;
template <typename... Is>
ck::f4x2_pk_t operator()(Is...)
{
float tmp0 = float(std::rand()) / float(RAND_MAX);
float tmp1 = float(std::rand()) / float(RAND_MAX);
float fp32_tmp0 = min_value + tmp0 * (max_value - min_value);
float fp32_tmp1 = min_value + tmp1 * (max_value - min_value);
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(ck::float2_t{fp32_tmp0, fp32_tmp1})};
}
};
template <typename T>
struct GeneratorTensor_4
{

View File

@@ -508,6 +508,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
@@ -589,6 +617,40 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
const f4x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
@@ -686,6 +748,39 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
const f4x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
@@ -748,6 +843,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};

View File

@@ -470,6 +470,13 @@ struct scalar_type<e8m0_bexp_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f4x2_pk_t>
{
using type = f4x2_pk_t::type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bool>
{