[CK TILE] GEMM with packed i4 (#1885)

* [CK TILE] GEMM with packed i4

* Fixes

* fixes

* fixes

* fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-20 09:59:49 +01:00
committed by GitHub
parent 824e2c1737
commit 4d9973ec8e
32 changed files with 882 additions and 305 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,20 +9,166 @@
namespace ck_tile {
namespace element_wise {
#if 0
// Fast int4x4 to fp16x8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
//
// Converts four 4-bit fields of `q` into four fp16 values, each equal to (field - 8).
//   res.lo <- fields at bits [3:0] and [19:16] of q (selected by LO)
//   res.hi <- fields at bits [7:4] and [23:20] of q (selected by HI)
// Bits [15:8] and [31:24] of q are masked away; callers shift q to reach them.
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
{
    const int LO = 0x000f000f;
    const int HI = 0x00f000f0;
    // 0x6400 is fp16 1024.0; OR-ing a field into its mantissa bits yields 1024 + n (LO)
    // or 1024 + 16 * n (HI) without any int-to-float conversion instruction.
    const int EX = 0x64006400;
    int lo;
    int hi;
    // Extract the two int4 at low bit and create two fp16 number.
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
    // Extract the two int4 at high bit and create two fp16 number.
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
    const int SUB = 0xE408E408; // half2 {-1032, -1032}
    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
    const int ADD = 0xd480d480; // half2 {-72, -72}
    fp16x4_t res;
    // for two fp16 from lowbit, subtract 1032 to get correct fp16 value
    // (1024 + n) - 1032 = n - 8
    asm volatile("v_pk_add_f16 %0, %1, %2"
                 : "=v"(res.lo)
                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
    // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
    // (1024 + 16 * n) / 16 - 72 = n - 8
    asm volatile(
        "v_pk_fma_f16 %0, %1, %2, %3"
        : "=v"(res.hi)
        : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
    return res;
}
// Same 4-bit -> fp16 conversion as i4_to_half4 (each selected field becomes field - 8),
// followed by a packed fp16 multiply of both result halves by `scale`.
//   q     : packed 4-bit fields; bits [3:0], [19:16] feed res.lo and bits [7:4], [23:20]
//           feed res.hi. The remaining bits are masked away.
//   scale : fp16 pair applied to res.lo and to res.hi with the same packed multiply.
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
{
    const int LO = 0x000f000f;
    const int HI = 0x00f000f0;
    // 0x6400 is fp16 1024.0; OR-ing a field into its mantissa bits yields 1024 + n (LO)
    // or 1024 + 16 * n (HI).
    const int EX = 0x64006400;
    int lo;
    int hi;
    // Extract the two int4 at low bit and create two fp16 number.
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX));
    // Extract the two int4 at high bit and create two fp16 number.
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX));
    const int SUB = 0xE408E408; // half2 {-1032, -1032}
    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
    const int ADD = 0xd480d480; // half2 {-72, -72}
    fp16x4_t res;
    // low fields: (1024 + n) - 1032 = n - 8
    asm volatile("v_pk_add_f16 %0, %1, %2"
                 : "=v"(res.lo)
                 : "v"(bit_cast<fp16x2_t>(lo)), "v"(bit_cast<fp16x2_t>(SUB)));
    // high fields: (1024 + 16 * n) / 16 - 72 = n - 8
    asm volatile(
        "v_pk_fma_f16 %0, %1, %2, %3"
        : "=v"(res.hi)
        : "v"(bit_cast<fp16x2_t>(hi)), "v"(bit_cast<fp16x2_t>(MUL)), "v"(bit_cast<fp16x2_t>(ADD)));
    // Apply the scale to both halves of the already-dequantized result.
    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale));
    asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale));
    return res;
}
// Converts the four 4-bit fields held in bits [15:0] of `q` into four bf16 values, each
// equal to (field - 8). Bits [31:16] of `q` are ignored.
//
// Method: each field is placed in the low mantissa byte of the fp32 constant 2^23
// (0x4B000000), which yields the float 2^23 + n. Subtracting 2^23 + 8 leaves n - 8, and the
// upper 16 bits of that fp32 value are taken as the bf16 result.
//
// Type punning is done with bit_cast rather than by writing through a
// reinterpret_cast<uint32_t*> alias of a float array, which violates strict aliasing.
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
    // Spread the nibbles so that nibble k sits in the low 4 bits of byte k.
    const uint32_t i8s =
        (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

    static constexpr uint32_t fp32_base = 0x4B000000; // fp32 2^23 = 8388608
    constexpr float fp32_bias           = 8388616.f;  // 2^23 + 8

    // Selector 0x765k: bytes 3..1 come from fp32_base, byte 0 is byte k of i8s.
    const float f0 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7650)) - fp32_bias;
    const float f1 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7651)) - fp32_bias;
    const float f2 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7652)) - fp32_bias;
    const float f3 = bit_cast<float>(__byte_perm(i8s, fp32_base, 0x7653)) - fp32_bias;

    // Selector 0x7632 packs the upper 16 bits of the first operand into the low half of the
    // result and the upper 16 bits of the second operand into the high half.
    // NOTE(review): this places f1 below f0 (and f3 below f2) in the packed word, exactly as
    // the previous implementation did - confirm this lane order is what callers expect.
    bf16x4_t res;
    res.lo = bit_cast<bf16x2_t>(
        __byte_perm(bit_cast<uint32_t>(f1), bit_cast<uint32_t>(f0), 0x7632));
    res.hi = bit_cast<bf16x2_t>(
        __byte_perm(bit_cast<uint32_t>(f3), bit_cast<uint32_t>(f2), 0x7632));
    return res;
}
// Element-wise functor that expands one pk_int4x4_t (eight packed 4-bit values) into eight
// 16-bit floating-point values, with no scaling applied.
struct PassThroughPack8
{
    // Marker constant; presumably queried by callers to detect pack-of-8 functors.
    static constexpr bool is_pack8_invocable = true;

    // Generic signature; only the overloads below are defined in this part of the file.
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    // fp16 output. i4_to_half4 reads nibbles 0, 1, 4 and 5 of its argument, so shifting the
    // packed word right by 8 bits exposes nibbles 2, 3, 6 and 7 for the upper half.
    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const
    {
        const int packed = bit_cast<int>(x);
        y.lo             = i4_to_half4(packed);
        y.hi             = i4_to_half4(packed >> 8);
    }

    // bf16 output. i4_to_bhalf4 reads only the low 16 bits of its argument, so the upper
    // half is produced from the packed word shifted right by 16 bits.
    CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const
    {
        const int packed = bit_cast<int>(x);
        y.lo             = i4_to_bhalf4(packed);
        y.hi             = i4_to_bhalf4(packed >> 16);
    }
};
// Element-wise functor that expands one pk_int4x4_t (eight packed 4-bit values) into eight
// fp16 values and multiplies them by a caller-supplied fp16 pair.
struct DequantPack8
{
    // Marker constant; presumably queried by callers to detect pack-of-8 functors.
    static constexpr bool is_pack8_invocable = true;

    // Generic signature; only the overload below is defined in this part of the file.
    template <typename Y, typename X, typename Z>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;

    // y: output, x: packed 4-bit input, z: scale forwarded unchanged to i4_to_half4_scale.
    // The second call sees the packed word shifted right by 8 bits so that it converts the
    // nibbles the first call skipped.
    CK_TILE_HOST_DEVICE constexpr void
    operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const
    {
        const int packed = bit_cast<int>(x);
        y.lo             = i4_to_half4_scale(packed, z);
        y.hi             = i4_to_half4_scale(packed >> 8, z);
    }
};
// Element-wise functor that expands one packed two-element input into a pair of fp16 values.
struct PassThroughPack2
{
    // Generic signature; only the overloads below are defined in this part of the file.
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    // The f8x2 -> fp16x2 overload is currently compiled out.
#if 0
    CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const
    {
        auto t = type_convert<float2_t>(x);
        y      = type_convert<fp16x2_t>(t);
    }
#endif

    // pk_int4_t -> fp16x2: the low nibble of the packed byte goes to y.lo and the high nibble
    // to y.hi. Both nibbles are converted as unsigned values in [0, 15].
    // NOTE(review): unlike i4_to_half4, no offset of 8 is subtracted here - confirm that the
    // unsigned interpretation is intended.
    CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const
    {
        uint8_t x_u8 = bit_cast<uint8_t>(x);
        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
        y.lo         = type_convert<half_t>(x_l);
        y.hi         = type_convert<half_t>(x_h);
    }

    // Marker constant; presumably queried by callers to detect pack-of-2 functors.
    constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{