mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK TILE] GEMM with packed i4 (#1885)
* [CK TILE] GEMM with packed i4 * Fixes * fixes * fixes * fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,20 +9,166 @@
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
#if 0
|
||||
// Fast int4x4 to fp16x8_t data type conversion based on paper
|
||||
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
|
||||
// (https://arxiv.org/abs/2211.10017) and implementation:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
// Decodes four packed 4-bit values from `q` into four fp16 values, each equal to
// (nibble - 8).
//
// Bits consumed from `q`:
//   - bits 0-3 and 16-19 (mask LO) -> res.lo
//   - bits 4-7 and 20-23 (mask HI) -> res.hi
// All other bits of `q` are ignored.
//
// The conversion avoids integer-to-float instructions: each nibble is OR-ed into
// the mantissa of the fp16 constant 1024.0 (0x6400) and the bias is then removed
// with packed fp16 arithmetic.
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
|
||||
{
|
||||
// Nibble at bits 0-3 of each 16-bit half of q.
const int LO = 0x000f000f;
|
||||
// Nibble at bits 4-7 of each 16-bit half of q.
const int HI = 0x00f000f0;
|
||||
// 0x6400 is the fp16 bit pattern of 1024.0, replicated in both 16-bit halves.
const int EX = 0x64006400;
|
||||
|
||||
int lo;
|
||||
int hi;
|
||||
// Extract the two int4 values in the low nibble positions and form two fp16 numbers:
// lo = (q & LO) | EX, so each 16-bit half is 0x640n, i.e. 1024 + n.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
|
||||
// Extract the two int4 values in the high nibble positions and form two fp16 numbers:
// hi = (q & HI) | EX, so each 16-bit half is 0x64n0, i.e. 1024 + 16 * n.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
|
||||
|
||||
const int SUB = 0xE408E408; // half2 {-1032, -1032}
|
||||
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
|
||||
const int ADD = 0xd480d480; // half2 {-72, -72}
|
||||
|
||||
fp16x4_t res;
|
||||
|
||||
// Low-nibble pair: (1024 + n) - 1032 = n - 8.
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(res.lo)
|
||||
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
|
||||
|
||||
// High-nibble pair: (1024 + 16 * n) / 16 - 72 = (64 + n) - 72 = n - 8.
asm volatile(
|
||||
"v_pk_fma_f16 %0, %1, %2, %3"
|
||||
: "=v"(res.hi)
|
||||
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Decodes four packed 4-bit values from `q` into four fp16 values equal to
// (nibble - 8), then multiplies both resulting fp16 pairs by `scale`.
//
// Bits consumed from `q`:
//   - bits 0-3 and 16-19 (mask LO) -> res.lo
//   - bits 4-7 and 20-23 (mask HI) -> res.hi
//
// `scale` is a packed pair; the same pair is applied to res.lo and to res.hi
// with a packed multiply, so each lane of `scale` multiplies the matching lane
// of each pair.
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
|
||||
{
|
||||
// Nibble at bits 0-3 of each 16-bit half of q.
const int LO = 0x000f000f;
|
||||
// Nibble at bits 4-7 of each 16-bit half of q.
const int HI = 0x00f000f0;
|
||||
// 0x6400 is the fp16 bit pattern of 1024.0, replicated in both 16-bit halves.
const int EX = 0x64006400;
|
||||
|
||||
int lo;
|
||||
int hi;
|
||||
// Extract the two int4 values in the low nibble positions and form two fp16 numbers:
// lo = (q & LO) | EX, so each 16-bit half is 1024 + n.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
|
||||
// Extract the two int4 values in the high nibble positions and form two fp16 numbers:
// hi = (q & HI) | EX, so each 16-bit half is 1024 + 16 * n.
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
|
||||
|
||||
const int SUB = 0xE408E408; // half2 {-1032, -1032}
|
||||
const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
|
||||
const int ADD = 0xd480d480; // half2 {-72, -72}
|
||||
|
||||
fp16x4_t res;
|
||||
|
||||
// Low-nibble pair: (1024 + n) - 1032 = n - 8.
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(res.lo)
|
||||
: "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
|
||||
|
||||
// High-nibble pair: (1024 + 16 * n) / 16 - 72 = n - 8.
asm volatile(
|
||||
"v_pk_fma_f16 %0, %1, %2, %3"
|
||||
: "=v"(res.hi)
|
||||
: "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
|
||||
|
||||
// Apply the dequantization scale to the low-nibble pair.
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
|
||||
|
||||
// Apply the same scale to the high-nibble pair.
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Decodes the four 4-bit values held in the low 16 bits of `q` into four bf16
// values, each equal to (nibble - 8). Bits 16-31 of `q` are ignored.
//
// Each nibble is placed in the low mantissa byte of the float constant 2^23
// (0x4B000000), which yields the float 2^23 + n; subtracting 2^23 + 8 leaves
// n - 8. The upper 16 bits of each float are then taken as the bf16 result.
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
{
|
||||
// Spread the four low nibbles of q so that nibble k sits in the low 4 bits of byte k.
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
|
||||
|
||||
// Bit pattern of the float 8388608.0f (2^23).
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
|
||||
// NOTE(review): the float array is written and read through a uint32_t* alias;
// confirm this type punning is acceptable for the target compiler.
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
// Selector 0x765k: byte 0 comes from byte k of i8s, bytes 1-3 from fp32_base
// (assuming the CUDA/HIP __byte_perm selector convention), giving 0x4B0000nn.
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
|
||||
|
||||
// Remove the 2^23 bias and the +8 offset in one step: (2^23 + n) - (2^23 + 8) = n - 8.
fp32_intermediates[0] -= 8388616.f;
|
||||
fp32_intermediates[1] -= 8388616.f;
|
||||
fp32_intermediates[2] -= 8388616.f;
|
||||
fp32_intermediates[3] -= 8388616.f;
|
||||
|
||||
bf16x4_t res;
|
||||
// Selector 0x7632 packs the upper 16 bits of two floats into one 32-bit word:
// the low half comes from the first argument, the high half from the second.
// NOTE(review): this puts element 1 in the low half and element 0 in the high
// half of res.lo — confirm the intended element order.
res.lo = bit_cast<bf16x2_t>(
|
||||
__byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
|
||||
// Same packing for elements 3 (low half) and 2 (high half).
res.hi = bit_cast<bf16x2_t>(
|
||||
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Element-wise functor that converts one pk_int4x4_t (reinterpreted as a 32-bit
// int) into eight fp16 or eight bf16 values, without scaling.
//
// NOTE(review): the concrete overloads are declared CK_TILE_HOST_DEVICE constexpr
// but call i4_to_half4 / i4_to_bhalf4, which are declared CK_TILE_DEVICE in this
// file — confirm this combination is intended.
struct PassThroughPack8
|
||||
{
|
||||
// Generic form; only declared here, no definition is visible in this section.
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
// fp16 path: i4_to_half4 reads nibbles at bits 0-7 and 16-23, so the second
// call shifts by 8 to expose the nibbles at bits 8-15 and 24-31.
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y.lo = i4_to_half4(bit_cast<int>(x));
|
||||
y.hi = i4_to_half4(bit_cast<int>(x) >> 8);
|
||||
}
|
||||
|
||||
// bf16 path: i4_to_bhalf4 reads only the low 16 bits, so the second call
// shifts by 16 to decode the upper four nibbles.
CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y.lo = i4_to_bhalf4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
}
|
||||
// Compile-time flag exposed by this functor; presumably queried by callers to
// detect 8-element packed invocation — verify against the users of this trait.
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
// Element-wise functor that converts one pk_int4x4_t (reinterpreted as a 32-bit
// int) into eight fp16 values and multiplies them by the packed scale `z`.
//
// NOTE(review): the concrete overload is declared CK_TILE_HOST_DEVICE constexpr
// but calls i4_to_half4_scale, which is declared CK_TILE_DEVICE in this file —
// confirm this combination is intended.
struct DequantPack8
|
||||
{
|
||||
// Generic form; only declared here, no definition is visible in this section.
template <typename Y, typename X, typename Z>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
// fp16 path: i4_to_half4_scale reads nibbles at bits 0-7 and 16-23, so the
// second call shifts by 8 to expose the nibbles at bits 8-15 and 24-31.
// The same scale pair `z` is applied to both halves of the output.
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
|
||||
{
|
||||
y.lo = i4_to_half4_scale(bit_cast<int>(x), z);
|
||||
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
}
|
||||
|
||||
// Compile-time flag exposed by this functor; presumably queried by callers to
// detect 8-element packed invocation — verify against the users of this trait.
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
// Element-wise functor that converts a packed pair of source values into a
// packed pair of fp16 values.
//
// NOTE(review): this listing appears to contain both the removed and the added
// lines of a diff: two consecutive f8x2_t operator() signatures (half2_t and
// fp16x2_t destinations) and two consecutive assignments to `y`. Reconcile
// against the real header before relying on this text.
struct PassThroughPack2
|
||||
{
|
||||
// Generic form; only declared here, no definition is visible in this section.
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
// NOTE(review): signature with no body; looks like the pre-change diff line.
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
|
||||
// The f8x2_t -> fp16x2_t overload below is compiled out by this nested #if 0.
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
|
||||
{
|
||||
// Convert through float2_t, then narrow to the fp16 pair.
auto t = type_convert<float2_t>(x);
|
||||
// NOTE(review): likely the pre-change diff line (half2_t spelling).
y = type_convert<half2_t>(t);
|
||||
y = type_convert<fp16x2_t>(t);
|
||||
}
|
||||
#endif
|
||||
|
||||
// pk_int4_t -> fp16x2_t: split the byte into its low and high nibble and
// convert each to half_t as an unsigned value in [0, 15].
// NOTE(review): no -8 offset is applied here, unlike i4_to_half4 in this
// file — confirm that the differing value range is intended.
CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
|
||||
{
|
||||
uint8_t x_u8 = bit_cast<uint8_t>(x);
|
||||
// Low nibble (bits 0-3).
uint8_t x_l = (x_u8 & 0x0f) >> 0;
|
||||
// High nibble (bits 4-7).
uint8_t x_h = (x_u8 & 0xf0) >> 4;
|
||||
|
||||
y.lo = type_convert<half_t>(x_l);
|
||||
y.hi = type_convert<half_t>(x_h);
|
||||
}
|
||||
|
||||
// Compile-time flag exposed by this functor; presumably queried by callers to
// detect 2-element packed invocation — verify against the users of this trait.
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user