mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] layernorm support fused-quant/fused-add (#1604)
* add prenorm/postnorm support, refactor using generate.py * update README * update README * fix format * update some description and fix format * update format * format * use non-raw for loading * format and update n4096 * dynamic-quant ready * update readme * support fused dynamic-quant * update fused-quant, with smooth * update README * update args * update some based on comment
This commit is contained in:
104
include/ck_tile/core/numeric/int8.hpp
Normal file
104
include/ck_tile/core/numeric/int8.hpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// use int8_t directly for int8 arithemetic
|
||||
// here one can use ck_tile::int8_t to access original int8_t
|
||||
using int8_t = int8_t;
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<int8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
|
||||
{
|
||||
return 1; // not used
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<int8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user