Revert "feature:tf32:add initial conv3d fwd kernel support (#2763)" (#2848)

This reverts commit c51102144f.
This commit is contained in:
Illia Silin
2025-09-15 08:27:04 -07:00
committed by GitHub
parent c51102144f
commit 03b59f8c76
44 changed files with 175 additions and 1085 deletions

View File

@@ -1636,45 +1636,4 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
}
};
/******************* tf32 *************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8xf32;
template <>
struct intrin_mfma_f32_16x16x8xf32<16, 16>
{
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4xf32;
template <>
struct intrin_mfma_f32_32x32x4xf32<32, 32>
{
template <class FloatC>
__device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
} // namespace ck

View File

@@ -26,7 +26,6 @@ using byte = unsigned char;
using std::byte;
#endif
using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
@@ -462,38 +461,4 @@ using int64_t = long long;
using int64_t = long;
#endif
template <typename T>
inline const char* get_type_name()
{
if constexpr(is_same_v<T, half_t>)
return "fp16";
else if constexpr(is_same_v<T, bhalf_t>)
return "bf16";
else if constexpr(is_same_v<T, tf32_t>)
return "tf32";
else if constexpr(is_same_v<T, int4_t>)
return "int4";
else if constexpr(is_same_v<T, f4_t>)
return "f4";
else if constexpr(is_same_v<T, f6_t>)
return "f6";
else if constexpr(is_same_v<T, bf6_t>)
return "bf6";
else if constexpr(is_same_v<T, f8_t>)
return "f8";
else if constexpr(is_same_v<T, bf8_t>)
return "bf8";
else if constexpr(is_same_v<T, e8m0_bexp_t>)
return "e8m0";
else if constexpr(is_same_v<T, float>)
return "fp32";
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
else
return "unknown";
#else
else
return typeid(T).name();
#endif
}
} // namespace ck

View File

@@ -187,19 +187,6 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int
return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
}
template <typename Y, enable_if_t<is_same_v<Y, ck::tf32_t>, bool> = false>
inline __host__ __device__ constexpr float type_convert(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
u.int32 = u.int32 & 0xffffe000;
return u.fp32;
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)