mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
* Add basic fp8 definitions and prn-generator
* Format
* Add fp8<->fp32 type_convert
* Format
* Split type_convert and cast_to/from_f8
* Format
* Minor fix
* Minor fix
* Move fp8 utils to a separate header
* Add elementwise ops
* Add fp8_convert_sr
* Format
* Add element op
* Eliminate magic numbers
* Split f8_convert_sr in host and device
* Format
* Add some constexpr
* Add a datatype test
* Format
* Another format
* Add fp8<->fp16 tests
* Update type_converts
* Format
* Add fp16 casting functions
* Format
* Use seed as a runtime arg
* Use element location for PRNG
* Format
* Add fp8<->fp16 to PassThrough element op
* Clean up
* Merge host and device implementations
* Add comments on rounding modes
* Remove leftover code
* Put type_converts into a separate header
* Put random number gen to a separate header
* Rearrange f8_utils' namespaces
* Refactor type_convert.hpp
* Move f8_t definition
[ROCm/composable_kernel commit: f0c620c42e]
54 lines
1.9 KiB
C++
54 lines
1.9 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
namespace ck {
|
|
|
|
// Pseudo random number generator
|
|
// version for fp32
|
|
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
|
|
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
|
{
|
|
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
|
|
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
|
drop_bits ^= x >> 16;
|
|
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
|
drop_bits *= 0x7000149;
|
|
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
|
// So, it can have an effect of using same id for multiple elements when the id is very
|
|
// large!
|
|
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
|
return rng;
|
|
}
|
|
|
|
// version for fp16
|
|
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
|
|
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
|
|
{
|
|
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
|
|
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
|
|
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
|
|
drop_bits *= 0x7000149;
|
|
// NOTE: If id is in 64 bit, we are only using lower 32 bit.
|
|
// So, it can have an effect of using same id for multiple elements when the id is very
|
|
// large!
|
|
uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
|
|
return rng;
|
|
}
|
|
|
|
// Fallback version: returns 0 when the data type is neither fp16 nor fp32
|
|
template <typename T,
|
|
uint32_t seed_t,
|
|
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
|
|
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
|
|
{
|
|
std::ignore = id;
|
|
std::ignore = val;
|
|
std::ignore = seed;
|
|
|
|
return 0;
|
|
}
|
|
|
|
} // namespace ck
|