mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Fix universal gemm profiler for pk_i4_t (#1790)
* Fix universal gemm profiler for pk_i4_t
* fix
[ROCm/composable_kernel commit: 888317e698]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
|
||||
else
|
||||
os << delim;
|
||||
|
||||
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>)
|
||||
using RangeType = ck::remove_cvref_t<decltype(v)>;
|
||||
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
|
||||
std::is_same_v<RangeType, ck::bhalf_t>)
|
||||
{
|
||||
os << ck::type_convert<float>(v);
|
||||
}
|
||||
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
|
||||
{
|
||||
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
|
||||
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
|
||||
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
|
||||
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
|
||||
}
|
||||
else
|
||||
{
|
||||
os << static_cast<T>(v);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
|
||||
{
|
||||
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
|
||||
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
auto l_f32 = ck::type_convert<float>(x_l);
|
||||
auto h_f32 = ck::type_convert<float>(x_h);
|
||||
|
||||
return {l_f32, h_f32};
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user