mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
This reverts commit c51102144f.
This commit is contained in:
@@ -1636,45 +1636,4 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
|
||||
}
|
||||
};
|
||||
|
||||
/******************* tf32 *************************************/
|
||||
// Wrapper around the xf32 (TF32) 16x16x8 MFMA instruction.
// Only the <16, 16> wave tile is specialized; the primary template is
// intentionally left undefined so that other tile shapes fail to compile.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8xf32;

template <>
struct intrin_mfma_f32_16x16x8xf32<16, 16>
{
    // Performs reg_c = mfma(reg_a, reg_b, reg_c) in place on the first
    // float4_t slot of the accumulator.
    //
    // reg_a, reg_b : two-float operand vectors handed straight to the builtin.
    // reg_c        : accumulator container; must expose AsType<float4_t>().
    //
    // On targets where __gfx94__ is not defined the call compiles to a no-op
    // and reg_c is left untouched.
    template <class FloatC>
    __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx94__)
        // __builtin_amdgcn_mfma_f32_16x16x8_xf32 consumes and returns a
        // 4-float accumulator vector, so the accumulator must be viewed as
        // float4_t (float16_t is the accumulator width of the 32x32 variant).
        // The three trailing immediate modifiers are left at 0.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
        // Instruction unavailable on this target: consume the arguments to
        // avoid unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
||||
|
||||
// Wrapper around the xf32 (TF32) 32x32x4 MFMA instruction.
// The primary template is declaration-only; just the <32, 32> wave tile
// below provides an implementation.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4xf32;

template <>
struct intrin_mfma_f32_32x32x4xf32<32, 32>
{
    // Accumulates in place: the first float16_t slot of reg_c is passed to
    // the builtin as the input accumulator and overwritten with its result.
    //
    // reg_a, reg_b : two-float operand vectors forwarded to the builtin.
    // reg_c        : accumulator container; must expose AsType<float16_t>().
    //
    // When __gfx94__ is not defined nothing is computed and reg_c is not
    // modified.
    template <typename FloatC>
    __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
    {
#ifdef __gfx94__
        // The three trailing immediate modifiers are all passed as 0.
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4_xf32(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0,
                0,
                0);
#else
        // No hardware support on this target; just silence unused-argument
        // diagnostics.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -26,7 +26,6 @@ using byte = unsigned char;
|
||||
using std::byte;
|
||||
#endif
|
||||
|
||||
using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
|
||||
using bhalf_t = ushort;
|
||||
using half_t = _Float16;
|
||||
using int4_t = _BitInt(4);
|
||||
@@ -462,38 +461,4 @@ using int64_t = long long;
|
||||
using int64_t = long;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
inline const char* get_type_name()
|
||||
{
|
||||
if constexpr(is_same_v<T, half_t>)
|
||||
return "fp16";
|
||||
else if constexpr(is_same_v<T, bhalf_t>)
|
||||
return "bf16";
|
||||
else if constexpr(is_same_v<T, tf32_t>)
|
||||
return "tf32";
|
||||
else if constexpr(is_same_v<T, int4_t>)
|
||||
return "int4";
|
||||
else if constexpr(is_same_v<T, f4_t>)
|
||||
return "f4";
|
||||
else if constexpr(is_same_v<T, f6_t>)
|
||||
return "f6";
|
||||
else if constexpr(is_same_v<T, bf6_t>)
|
||||
return "bf6";
|
||||
else if constexpr(is_same_v<T, f8_t>)
|
||||
return "f8";
|
||||
else if constexpr(is_same_v<T, bf8_t>)
|
||||
return "bf8";
|
||||
else if constexpr(is_same_v<T, e8m0_bexp_t>)
|
||||
return "e8m0";
|
||||
else if constexpr(is_same_v<T, float>)
|
||||
return "fp32";
|
||||
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
|
||||
else
|
||||
return "unknown";
|
||||
#else
|
||||
else
|
||||
return typeid(T).name();
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -187,19 +187,6 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
|
||||
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
|
||||
}
|
||||
|
||||
template <typename Y, enable_if_t<is_same_v<Y, ck::tf32_t>, bool> = false>
|
||||
inline __host__ __device__ constexpr float type_convert(float x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
uint32_t int32;
|
||||
} u = {x};
|
||||
|
||||
u.int32 = u.int32 & 0xffffe000;
|
||||
return u.fp32;
|
||||
}
|
||||
|
||||
// Convert X to Y
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ constexpr Y type_convert_sp(X x)
|
||||
|
||||
Reference in New Issue
Block a user