mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat(HostTensor): Extend support for HostTensor class' >> operator to print more data types (#2691)
* feat(check_err): add a variable to adjust number of incorrect values to print
* feat(host_tensor): add printing capability for fp8 bf8 int8 int4
* fix(gemm_utils): update acceptable data type
* fix(host_tensor): print both 4 bit ints in pk_int4_t
* refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp
* feat(host_tensor): add print first n elements functions
This commit is contained in:
@@ -100,6 +100,7 @@ struct numeric_traits<pk_int4_t>
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
@@ -165,4 +166,24 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
|
||||
return res;
|
||||
}
|
||||
|
||||
/// @brief Unpack the two 4-bit fields of a pk_int4_t into a two-lane int8 vector.
///
/// The byte behind @p x is split into its low nibble (bits 0..3) and high nibble
/// (bits 4..7). Each nibble is sign-extended from 4 to 8 bits, so a nibble with
/// bit 3 set becomes a negative value in the range [-8, -1].
///
/// @param x Packed value holding two 4-bit fields in one byte.
/// @return  {low, high} by default; {high, low} when
///          CK_TILE_USE_PK4_LAYOUT_SHUFFLE is defined.
///
/// NOTE(review): this decodes each nibble as two's complement. The bodies of the
/// other pk_int4_t_to_* converters are not visible here -- confirm they assume
/// the same nibble encoding.
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
{
    const uint8_t packed = ck_tile::bit_cast<uint8_t>(x);

    // Branch-free 4-bit sign extension: (n ^ 0x08) - 0x08 leaves 0..7 unchanged
    // and maps 8..15 onto -8..-1.
    const int8_t lo = static_cast<int8_t>(((packed & 0x0F) ^ 0x08) - 0x08);
    const int8_t hi = static_cast<int8_t>(((packed >> 4) ^ 0x08) - 0x08);

#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
    // Shuffled layout: the high nibble occupies lane 0.
    const int8x2_t unpacked = {hi, lo};
#else
    // Default layout: the low nibble occupies lane 0.
    const int8x2_t unpacked = {lo, hi};
#endif
    return unpacked;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -176,65 +176,65 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
// using int8_t
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute__((ext_vector_type(4)));
|
||||
using int8x8_t = int8_t __attribute__((ext_vector_type(8)));
|
||||
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
|
||||
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// ui8
|
||||
// using uint8_t
|
||||
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
|
||||
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
|
||||
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
|
||||
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
|
||||
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
|
||||
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
|
||||
using uint8x2_t = uint8_t __attribute__((ext_vector_type(2)));
|
||||
using uint8x4_t = uint8_t __attribute__((ext_vector_type(4)));
|
||||
using uint8x8_t = uint8_t __attribute__((ext_vector_type(8)));
|
||||
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
|
||||
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
|
||||
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
|
||||
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
|
||||
using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
|
||||
#else
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
|
||||
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
|
||||
using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
|
||||
using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
|
||||
using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
|
||||
using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
|
||||
using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
|
||||
using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
|
||||
#endif
|
||||
|
||||
// pk_int4_t
|
||||
// using pk_int4_t
|
||||
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
|
||||
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using pk_int4x2_t = int8_t __attribute__((ext_vector_type(2)));
|
||||
using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
|
||||
using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8)));
|
||||
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
|
||||
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user