diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 9e3ccb025d..692d5ec504 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -162,6 +162,16 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) */ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0xcaccced0; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0xb8c0c4c8; + // register values [-1, -2, -3, -4] + static constexpr uint32_t reg2 = 0x44403800; + // register values [-5, -6, -7, -8] + static constexpr uint32_t reg3 = 0x4e4c4a48; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0xd2d4d6d8; // register values [7, 6, 5, 4] @@ -170,6 +180,7 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) static constexpr uint32_t reg2 = 0x4C484000; // register values [-5, -6, -7, -8] static constexpr uint32_t reg3 = 0x56545250; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; @@ -227,6 +238,16 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) */ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0Xc5c6c7c8; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0Xbcc0c2c4; + // register values [11, 10, 9, 8] + static constexpr uint32_t reg2 = 0X42403c00; + // register values [15, 14, 13, 12] + static constexpr uint32_t reg3 = 0X47464544; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0Xc9cacbcc; // register values [7, 6, 5, 4] @@ -235,6 +256,7 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) static constexpr uint32_t reg2 = 0X46444000; // register values [15, 14, 13, 12] static constexpr uint32_t reg3 = 0X4b4a4948; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;