mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)
- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats - Implement A(bf8) B(i4) support in universal GEMM - Use new implementation for i4 to fp8 conversion in Block Scale
This commit is contained in:
@@ -4,15 +4,29 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
// Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
/**
|
||||
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
|
||||
* @see https://arxiv.org/abs/2211.10017
|
||||
* @see
|
||||
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
*
|
||||
* This function converts 4 4-bit integers into 4 fp16 values.
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
@@ -46,6 +60,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
{
|
||||
const int LO = 0x000f000f;
|
||||
@@ -81,6 +107,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function converts 4 4-bit integers into 4 bf16 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf16(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be bf16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
@@ -110,37 +148,55 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, each byte represents 2 int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp8(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be fp8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t fp8x4_lo, fp8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
// register values [3, 2, 1, 0]
|
||||
static constexpr uint32_t reg0 = 0xd2d4d6d8;
|
||||
// register values [7, 6, 5, 4]
|
||||
static constexpr uint32_t reg1 = 0xc0c8ccd0;
|
||||
// register values [-1, -2, -3, -4]
|
||||
static constexpr uint32_t reg2 = 0x4C484000;
|
||||
// register values [-5, -6, -7, -8]
|
||||
static constexpr uint32_t reg3 = 0x56545250;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(fp8x4_lo),
|
||||
[v_dst_hi] "+v"(fp8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
|
||||
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
|
||||
|
||||
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
|
||||
@@ -157,37 +213,55 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
|
||||
/**
|
||||
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
|
||||
*
|
||||
* @note `int q` contains 4 bytes, each byte represents 2 int4.
|
||||
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf8(-8)
|
||||
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
|
||||
* the output sequence will be bf8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
|
||||
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
|
||||
* function.
|
||||
*
|
||||
* @see permute_vectors_i4x4_b
|
||||
*/
|
||||
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t bf8x4_lo, bf8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
// register values [3, 2, 1, 0]
|
||||
static constexpr uint32_t reg0 = 0Xc9cacbcc;
|
||||
// register values [7, 6, 5, 4]
|
||||
static constexpr uint32_t reg1 = 0Xc0c4c6c8;
|
||||
// register values [11, 10, 9, 8]
|
||||
static constexpr uint32_t reg2 = 0X46444000;
|
||||
// register values [15, 14, 13, 12]
|
||||
static constexpr uint32_t reg3 = 0X4b4a4948;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(bf8x4_lo),
|
||||
[v_dst_hi] "+v"(bf8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
|
||||
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
|
||||
|
||||
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
|
||||
}
|
||||
|
||||
struct PassThroughPack8
|
||||
@@ -209,12 +283,12 @@ struct PassThroughPack8
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user