mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
fix compile error, fp8 not ready now
This commit is contained in:
@@ -54,4 +54,3 @@
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
|
||||
@@ -267,7 +267,15 @@ struct numeric<bfloat16_t>
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); }
|
||||
// maximum rounding error
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeeeeee mmmmmmm
|
||||
// 0 01111110 0000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
|
||||
@@ -42,13 +42,13 @@ enum class fp8_rounding_mode
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
struct alignas(1) float8_e4m3_t
|
||||
@@ -581,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
|
||||
|
||||
// clang-format off
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
|
||||
@@ -589,14 +589,14 @@ CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<round
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
|
||||
else return bf8_raw_t{0};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -610,7 +610,7 @@ CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -625,23 +625,23 @@ CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant<rounding> = {})
|
||||
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant<rounding> = {})
|
||||
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x)
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
|
||||
{
|
||||
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x)
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
|
||||
{
|
||||
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
|
||||
}
|
||||
@@ -706,7 +706,14 @@ struct numeric<fp8_t>
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); }
|
||||
// bin : 7 6543 210
|
||||
// bits: s eeee mmm
|
||||
// 0 0110 000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
|
||||
@@ -766,7 +773,14 @@ struct numeric<bf8_t>
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); }
|
||||
// bin : 7 65432 10
|
||||
// bits: s eeeee mm
|
||||
// 0 01110 00 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
|
||||
|
||||
@@ -184,7 +184,14 @@ struct numeric<half_t>
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast<half_t>(0.5f); }
|
||||
// bin : f edcba 9876543210
|
||||
// bits: s eeeee mmmmmmmmmm
|
||||
// 0 01110 0000000000 (0.5)
|
||||
//
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
|
||||
|
||||
@@ -130,6 +130,7 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
|
||||
@@ -147,5 +148,24 @@ using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
|
||||
#else
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -40,8 +40,7 @@ template <typename InElementFunc,
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InTensor&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
using OutNativeType = typename native_t<OutDataType>::type;
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
@@ -54,7 +53,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
out_dstr_tensor.get_thread_buffer()(i) =
|
||||
static_cast<OutNativeType>(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...));
|
||||
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
|
||||
@@ -20,4 +20,3 @@
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
|
||||
@@ -4,4 +4,3 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -5,4 +5,3 @@
|
||||
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
@@ -18,4 +18,3 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user