mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
[CK] add composable kernel support on gfx1250 (#6978) ## Motivation Add composable kernel support on gfx1250. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Qun Lin <qlin@amd.com> Co-authored-by: jialuo12_amdeng <jia.luo@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: hsivasun_amdeng <haresh.sivasuntharampillai@amd.com>
101 lines
2.7 KiB
C++
101 lines
2.7 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#ifndef CK_CODE_GEN_RTC
|
|
|
|
#include "ck/utility/scale_utils.hpp"
|
|
#include "ck/utility/type.hpp"
|
|
|
|
namespace ck {
|
|
|
|
struct e5m3_scale_t
|
|
{
|
|
using type = uint8_t;
|
|
using Format = utils::ScaleFormat<5, 3>;
|
|
|
|
static constexpr int exponent_bits = Format::exponent_bits;
|
|
static constexpr int mantissa_bits = Format::mantissa_bits;
|
|
static constexpr type value_mask = Format::value_mask;
|
|
static constexpr type nan_mask = Format::nan_mask;
|
|
static constexpr type max_finite = Format::max_finite;
|
|
static constexpr int bias = Format::bias;
|
|
|
|
type data;
|
|
|
|
__host__ __device__ constexpr e5m3_scale_t() : data{type{}} {}
|
|
__host__ __device__ constexpr explicit e5m3_scale_t(type init)
|
|
: data{static_cast<type>(init & value_mask)}
|
|
{
|
|
}
|
|
__host__ __device__ constexpr explicit e5m3_scale_t(int init)
|
|
: data{static_cast<type>(static_cast<type>(init) & value_mask)}
|
|
{
|
|
}
|
|
__host__ __device__ explicit e5m3_scale_t(float scale)
|
|
{
|
|
#if defined(__gfx1250__)
|
|
union
|
|
{
|
|
float fval;
|
|
uint32_t i32val;
|
|
uint8_t i8val[4];
|
|
} val;
|
|
val.fval = scale;
|
|
uint32_t ival = 0;
|
|
const float max_e5m3 = 114688.0f;
|
|
// if x is not +/- infinity or nan
|
|
if((val.i32val & 0x7F800000) != 0x7F800000)
|
|
// clip float value
|
|
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_e5m3, -max_e5m3);
|
|
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, true);
|
|
val.i32val = ival;
|
|
data = val.i8val[0];
|
|
#else
|
|
data = Format::encode(scale);
|
|
#endif
|
|
}
|
|
|
|
__host__ __device__ explicit operator float() const
|
|
{
|
|
#if defined(__gfx1250__)
|
|
union
|
|
{
|
|
unsigned int i32val;
|
|
uint8_t i8val[4];
|
|
} val;
|
|
val.i8val[0] = this->data;
|
|
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, true);
|
|
#else
|
|
return Format::decode(data);
|
|
#endif
|
|
}
|
|
|
|
__host__ __device__ constexpr bool operator==(const e5m3_scale_t& other) const
|
|
{
|
|
return data == other.data && !is_nan();
|
|
}
|
|
|
|
__host__ __device__ constexpr bool operator!=(const e5m3_scale_t& other) const
|
|
{
|
|
return !(*this == other);
|
|
}
|
|
|
|
__host__ __device__ constexpr bool is_nan() const { return Format::is_nan(data); }
|
|
};
|
|
|
|
namespace utils {
|
|
|
|
template <>
|
|
__host__ __device__ inline constexpr int32_t get_exponent_value<e5m3_scale_t>(e5m3_scale_t x)
|
|
{
|
|
return e5m3_scale_t::Format::exponent(x.data);
|
|
}
|
|
|
|
} // namespace utils
|
|
|
|
} // namespace ck
|
|
|
|
#endif
|