enable fp4 for universal gemm - without any scaling

This commit is contained in:
Sami Remes
2026-02-03 03:10:35 -05:00
parent 4d241289c9
commit b47853d3fe
8 changed files with 205 additions and 113 deletions

View File

@@ -4,6 +4,7 @@
#pragma once
#include <cstdlib>
#include <mutex>
#include <thread>
#include "ck_tile/core.hpp"
@@ -456,27 +457,42 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by PackedSize
// So a_m_k(m,0) and a_m_k(m,1) return the same packed byte
const pk_fp4_t pk_val = a_m_k(m, k);
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
}
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = a_m_k(m, k);
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
// HostTensor automatically handles packed indexing
const pk_fp4_t pk_val = b_k_n(k, n);
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
}
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
// HostTensor automatically handles packed indexing
const pk_int4_t pk_val = b_k_n(k, n);
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
}
else
{
@@ -759,7 +775,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
@@ -779,7 +795,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
@@ -871,7 +887,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
if(k % 2 == 1)
v_a = fp32_val.hi;
else