fix compile error, fp8 not ready now

This commit is contained in:
carlushuang
2024-03-18 07:58:00 +00:00
parent f55c7629bc
commit 886d040a81
10 changed files with 67 additions and 24 deletions

View File

@@ -54,4 +54,3 @@
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -267,7 +267,15 @@ struct numeric<bfloat16_t>
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); }
// maximum rounding error
// bin : f edcba 9876543210
// bits: s eeeeeeee mmmmmmm
// 0 01111110 0000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()

View File

@@ -42,13 +42,13 @@ enum class fp8_rounding_mode
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
@@ -581,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
// clang-format off
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
@@ -589,14 +589,14 @@ CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<round
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return bf8_raw_t{0};
}
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -610,7 +610,7 @@ CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
#endif
}
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -625,23 +625,23 @@ CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant<rounding> = {})
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant<rounding> = {})
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x)
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
{
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}
CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x)
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
{
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
@@ -706,7 +706,14 @@ struct numeric<fp8_t>
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); }
// bin : 7 6543 210
// bits: s eeee mmm
// 0 0110 000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
@@ -766,7 +773,14 @@ struct numeric<bf8_t>
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); }
// bin : 7 65432 10
// bits: s eeeee mm
// 0 01110 00 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()

View File

@@ -184,7 +184,14 @@ struct numeric<half_t>
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast<half_t>(0.5f); }
// bin : f edcba 9876543210
// bits: s eeeee mmmmmmmmmm
// 0 01110 0000000000 (0.5)
//
CK_TILE_HOST_DEVICE static constexpr half_t round_error()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x3800));
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr half_t infinity()

View File

@@ -130,6 +130,7 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
@@ -147,5 +148,24 @@ using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
} // namespace ck_tile

View File

@@ -40,8 +40,7 @@ template <typename InElementFunc,
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
const InTensor&... in_dstr_tensors)
{
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
using OutNativeType = typename native_t<OutDataType>::type;
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
@@ -54,7 +53,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
out_dstr_tensor.get_thread_buffer()(i) =
static_cast<OutNativeType>(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...));
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
});
return out_dstr_tensor;

View File

@@ -20,4 +20,3 @@
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"

View File

@@ -4,4 +4,3 @@
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -5,4 +5,3 @@
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -18,4 +18,3 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"