[CK TILE] Support fp8/fp16 with pk_int4_t as data types for tensors A and B (#2805)

- Add support for tensor A/B in both fp16+pk_int4_t and fp8+pk_int4_t formats
- Implement A(bf8) B(i4) support in universal GEMM
- Use new implementation for i4 to fp8 conversion in Block Scale
This commit is contained in:
Cong Ma
2025-09-09 17:40:52 -06:00
committed by GitHub
parent 91178b4011
commit 82890192dd
15 changed files with 320 additions and 135 deletions

View File

@@ -344,6 +344,24 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};
/// Type bundle for GEMM with an fp8 A tensor, a packed-int4 B tensor
/// and fp16 output. Accumulation is done in fp32 for precision.
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
    using ADataType   = ck_tile::fp8_t;      // A: 8-bit float
    using BDataType   = ck_tile::pk_int4_t;  // B: two int4 values packed per byte
    using CDataType   = ck_tile::half_t;     // C: fp16 result
    using AccDataType = float;               // accumulate in fp32
};
/// Type bundle for GEMM with a bf8 A tensor, a packed-int4 B tensor
/// and fp16 output. Accumulation is done in fp32 for precision.
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
    using ADataType   = ck_tile::bf8_t;      // A: 8-bit bfloat
    using BDataType   = ck_tile::pk_int4_t;  // B: two int4 values packed per byte
    using CDataType   = ck_tile::half_t;     // C: fp16 result
    using AccDataType = float;               // accumulate in fp32
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{

View File

@@ -1,6 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
@@ -90,61 +92,6 @@ void permute_tensor_b(Tensor& tensor)
}
}
// Shuffle each group of 8 int4 values (4 packed bytes along K, per column of N)
// from source nibble order 01234567 into the hardware-expected order 20643175.
// Assumes K is a multiple of 8; only even K rows carry packed bytes (each byte
// covers two logical K positions).
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
    const auto K = tensor.get_length(0);
    const auto N = tensor.get_length(1);
    // Output byte o of a group packs nibble src_hi[o] (high) with src_lo[o] (low).
    constexpr int src_hi[4] = {2, 6, 3, 7};
    constexpr int src_lo[4] = {0, 4, 1, 5};
    for(int n = 0; n < N; ++n)
    {
        for(int k0 = 0; k0 < K; k0 += 8)
        {
            // Unpack the group's 4 bytes into 8 separate nibbles
            // (high nibble of byte b lands at even slot 2b).
            int8_t nib[8];
            for(int b = 0; b < 4; ++b)
            {
                const int8_t packed = tensor(k0 + b * 2, n).data;
                nib[b * 2 + 0]      = (packed >> 4) & 0xf;
                nib[b * 2 + 1]      = packed & 0xf;
            }
            // Repack the nibbles in permuted order.
            for(int o = 0; o < 4; ++o)
            {
                const int8_t repacked =
                    static_cast<int8_t>((nib[src_hi[o]] << 4) | nib[src_lo[o]]);
                tensor(k0 + o * 2, n) = repacked;
            }
        }
    }
}
template <typename GemmConfig,
typename Invoker,
typename ADataType,
@@ -399,7 +346,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
BLayout,
CLayout>(b_k_n_dev);
}
permute_vectors_i4x4_b(b_k_n_dev);
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else

View File

@@ -5,7 +5,7 @@ KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
run_tests() {
for m in 512 1024; do
@@ -32,5 +32,8 @@ run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
run_tests "fp16i4"
run_tests "fp8i4"
run_tests "bf8i4"
set +x

View File

@@ -5,11 +5,8 @@
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "run_gemm_example_common.hpp"
@@ -58,7 +55,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "pk_int4_t")
else if(data_type == "fp16i4")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
@@ -74,6 +71,36 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "fp8i4")
{
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "bf8i4")
{
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");

View File

@@ -228,4 +228,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }

View File

@@ -5,6 +5,7 @@
#pragma once
#include <random>
#include <stdexcept>
#include "../00_shared/host_tensor_utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -217,7 +218,16 @@ int run_gemm_example_with_layouts(int argc,
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

View File

@@ -3,6 +3,7 @@
#pragma once
#include <random>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -208,7 +209,17 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

View File

@@ -4,6 +4,7 @@
#pragma once
#include <random>
#include <stdexcept>
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -308,7 +309,17 @@ int run_gemm_example_with_layouts(int argc,
aq_dev_buf.ToDevice(aq_tensor.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

View File

@@ -125,7 +125,7 @@ CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t_signed_conversion(const pk_in
float x_h = ((x_u8 & 0xf0) >> 4);
x_l = x_l > 7 ? x_l - 16 : x_l;
x_h = x_l > 7 ? x_l - 16 : x_l;
x_h = x_h > 7 ? x_h - 16 : x_h;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/bit_cast.hpp"
namespace ck_tile {
/**
* @brief Permute packed int4 vectors for device implementation compatibility
*
* This function transforms 4 pk_int4_t values from original layout to hardware-optimized layout:
* - Original layout (4 pk_int4_t): 0x76543210
* - Transformed layout (4 pk_int4_t): 0x75316420
*
* Each pk_int4_t contains two 4-bit values packed in the high and low nibbles of an int8_t
*
 * Example:
 *   - Input:  0x76, 0x54, 0x32, 0x10
 *   - Output: 0x75, 0x31, 0x64, 0x20
 *   (NOTE(review): verify this example against the implementation below — reading
 *   the listed bytes in memory order, the code permutes 0x76, 0x54, 0x32, 0x10 to
 *   0x57, 0x13, 0x46, 0x02, not the output bytes listed here.)
*
* @note Input tensor length must be a multiple of 4
*
* This transformation is required before transferring B matrix data (of type pk_int4_t) to device.
* The device conversion functions (i4_to_half4, i4_to_bhalf4, amd_assembly_i4_to_fp8x8,
* amd_assembly_i4_to_bf8x8) require data in 0x75316420 order to correctly convert pk_int4_t to
* other numeric types.
*/
// Rearrange every group of 4 packed-int4 bytes (8 nibbles) into the
// 0x75316420 layout expected by the device-side conversion routines.
// The tensor's element count must be a multiple of 4; see the block
// comment above for the full contract.
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
    auto buf = tensor.data();
    // Output byte o of a group packs nibble src_hi[o] (high) with src_lo[o] (low).
    constexpr int src_hi[4] = {2, 6, 3, 7};
    constexpr int src_lo[4] = {0, 4, 1, 5};
    for(size_t base = 0; base < tensor.size(); base += 4)
    {
        // Unpack the group's 4 bytes into 8 separate nibbles
        // (high nibble of byte b lands at even slot 2b).
        int8_t nib[8];
        for(int b = 0; b < 4; ++b)
        {
            const int8_t packed = bit_cast<int8_t>(buf[base + b]);
            nib[b * 2 + 0]      = (packed >> 4) & 0xf;
            nib[b * 2 + 1]      = packed & 0xf;
        }
        // Repack the nibbles in permuted order.
        for(int o = 0; o < 4; ++o)
        {
            const int8_t repacked =
                static_cast<int8_t>((nib[src_hi[o]] << 4) | nib[src_lo[o]]);
            buf[base + o] = bit_cast<pk_int4_t>(repacked);
        }
    }
}
} // namespace ck_tile

View File

@@ -50,7 +50,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
@@ -63,7 +63,7 @@ CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else

View File

@@ -4,15 +4,29 @@
#pragma once
#include "ck_tile/core.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile {
namespace element_wise {
// Fast int4x4 to fp16x8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
/**
* @brief Fast int4x4 to fp16x8_t data type conversion based on paper
* "Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"
* @see https://arxiv.org/abs/2211.10017
* @see
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
*
* This function converts 4 4-bit integers into 4 fp16 values.
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
* function.
*
* @see permute_vectors_i4x4_b
*/
CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
@@ -46,6 +60,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4(int q)
return res;
}
/**
* @brief This function dequantizes 4 int4 values into 4 fp16 values and applies scaling.
*
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp16(-8)
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
* the output sequence will be fp16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
* function.
*
* @see permute_vectors_i4x4_b
*/
CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
{
const int LO = 0x000f000f;
@@ -81,6 +107,18 @@ CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale)
return res;
}
/**
* @brief This function converts 4 4-bit integers into 4 bf16 values.
*
* @note `int q` contains 4 bytes, low 4 bits of each byte represent an int4.
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf16(-8)
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
* the output sequence will be bf16(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
* function.
*
* @see permute_vectors_i4x4_b
*/
CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
@@ -110,37 +148,55 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
return res;
}
/**
* @brief This function converts 8 packed 4-bit integers into 8 fp8 values.
*
* @note `int q` contains 4 bytes, each byte represents 2 int4.
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to fp8(-8)
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
* the output sequence will be fp8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
* function.
*
* @see permute_vectors_i4x4_b
*/
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t fp8x4_lo, fp8x4_hi;
float tmp_0, tmp_1;
// register values [3, 2, 1, 0]
static constexpr uint32_t reg0 = 0xd2d4d6d8;
// register values [7, 6, 5, 4]
static constexpr uint32_t reg1 = 0xc0c8ccd0;
// register values [-1, -2, -3, -4]
static constexpr uint32_t reg2 = 0x4C484000;
// register values [-5, -6, -7, -8]
static constexpr uint32_t reg3 = 0x56545250;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(fp8x4_lo),
[v_dst_hi] "+v"(fp8x4_hi),
[v_src] "+v"(src)
:);
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
return bit_cast<fp8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
@@ -157,37 +213,55 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
return res;
}
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
/**
* @brief This function converts 8 packed 4-bit integers into 8 bf8 values.
*
* @note `int q` contains 4 bytes, each byte represents 2 int4.
* @note This function assumes pk_int4_t has a bias of 8, meaning 0b0000 is converted to bf8(-8)
* @note The output ordering differs from input ordering. For example, when input is 0x76543210,
* the output sequence will be bf8(7, 3, 6, 2, 5, 1, 4, 0). Therefore, the input tensor
* must be preprocessed with permute_vectors_i4x4_b on the host side before using this
* function.
*
* @see permute_vectors_i4x4_b
*/
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
{
uint32_t src = static_cast<uint32_t>(a), src_hi;
uint32_t bf8x4_lo, bf8x4_hi;
float tmp_0, tmp_1;
// register values [3, 2, 1, 0]
static constexpr uint32_t reg0 = 0Xc9cacbcc;
// register values [7, 6, 5, 4]
static constexpr uint32_t reg1 = 0Xc0c4c6c8;
// register values [11, 10, 9, 8]
static constexpr uint32_t reg2 = 0X46444000;
// register values [15, 14, 13, 12]
static constexpr uint32_t reg3 = 0X4b4a4948;
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_hi_src] "+v"(src_hi),
[v_dst_lo] "+v"(bf8x4_lo),
[v_dst_hi] "+v"(bf8x4_hi),
[v_src] "+v"(src)
:);
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
tmp_res_odd = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
auto tmp_res_low = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x06040200);
auto tmp_res_high = __builtin_amdgcn_perm(tmp_res_odd, tmp_res_even, 0x07050301);
return bit_cast<bf8x8_t>((static_cast<uint64_t>(tmp_res_high) << 32) | tmp_res_low);
}
struct PassThroughPack8
@@ -209,12 +283,12 @@ struct PassThroughPack8
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
y = amd_assembly_i4_to_fp8x8(bit_cast<uint32_t>(x));
}
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
{
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
y = amd_assembly_i4_to_bf8x8(bit_cast<uint32_t>(x));
}
constexpr const static bool is_pack8_invocable = true;
};

View File

@@ -181,9 +181,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
static constexpr auto Scheduler = Traits::Scheduler;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
@@ -451,7 +449,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
}
@@ -471,7 +469,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
else
@@ -556,7 +554,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
reg_offset_for_row_data] +=
(c_warp_tensor
.get_thread_buffer()[reg_offset_for_row_data] *
scale_reg_f * kA_cvt_scale * kB_cvt_scale);
scale_reg_f);
});
}
}

View File

@@ -179,9 +179,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
static constexpr uint8_t kA_cvt_scale = std::is_same_v<ADataType, pk_int4_t> ? 16 : 1;
static constexpr uint8_t kB_cvt_scale = std::is_same_v<BDataType, pk_int4_t> ? 16 : 1;
static constexpr auto Scheduler = Traits::Scheduler;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
@@ -384,8 +382,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f *
kA_cvt_scale * kB_cvt_scale);
(c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
});
});
});

View File

@@ -14,6 +14,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename GemmConfig,
typename ADataType,
@@ -336,7 +337,17 @@ bool run_gemm_test_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();