feat(HostTensor): Extend support for HostTensor class' >> operator to print more data types (#2691)

* feat(check_err): add a variable to adjust number of incorrect values to print

* feat(host_tensor): add printing capability for fp8 bf8 int8 int4

* fix(gemm_utils): update acceptable data type

* fix(host_tensor): print both 4 bit ints in pk_int4_t

* refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp

* feat(host_tensor): add print first n elements functions
This commit is contained in:
Aviral Goel
2025-08-27 21:17:24 -04:00
committed by GitHub
parent 9751583f95
commit f5f795c4d6
5 changed files with 125 additions and 49 deletions

View File

@@ -100,6 +100,7 @@ struct numeric_traits<pk_int4_t>
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
@@ -165,4 +166,24 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
return res;
}
CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
int8_t x_l = (x_u8 & 0x0F);
int8_t x_h = (x_u8 & 0xF0) >> 4;
if(x_l & 0x08)
x_l |= 0xF0;
if(x_h & 0x08)
x_h |= 0xF0;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
int8x2_t res = {x_h, x_l};
#else
int8x2_t res = {x_l, x_h};
#endif
return res;
}
} // namespace ck_tile

View File

@@ -176,65 +176,65 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
using int8x2_t = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));
// ui8
// using uint8_t
using uint8x2_t = uint8_t __attribute((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
using uint8x2_t = uint8_t __attribute__((ext_vector_type(2)));
using uint8x4_t = uint8_t __attribute__((ext_vector_type(4)));
using uint8x8_t = uint8_t __attribute__((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));
using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
#endif
// pk_int4_t
// using pk_int4_t
using pk_int4x2_t = int8_t __attribute((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
using pk_int4x2_t = int8_t __attribute__((ext_vector_type(2)));
using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
} // namespace ck_tile