mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
enable fp4 for universal gemm - without any scaling
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
@@ -456,27 +457,42 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by PackedSize
|
||||
// So a_m_k(m,0) and a_m_k(m,1) return the same packed byte
|
||||
const pk_fp4_t pk_val = a_m_k(m, k);
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = a_m_k(m, k);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_fp4_t pk_val = b_k_n(k, n);
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = b_k_n(k, n);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -759,7 +775,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
@@ -779,7 +795,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
@@ -871,7 +887,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
|
||||
@@ -700,9 +700,6 @@ struct UniversalGemmKernel
|
||||
[&](auto i) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<i.value, AsLayout>>;
|
||||
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
|
||||
static_assert(GemmPipeline::GetVectorSizeA() == GemmPipeline::GetVectorSizeB(), "Vector size of A and B must be the same!");
|
||||
static_assert(GemmPipeline::GetVectorSizeA() == 32, "Vector size of A must be 16!");
|
||||
static_assert(GemmPipeline::GetVectorSizeB() == 32, "Vector size of B must be 16!");
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
|
||||
@@ -84,6 +84,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GetName()
|
||||
{
|
||||
return "COMPUTE_ASYNC";
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -517,13 +517,7 @@ struct UniversalGemmBasePolicy
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
@@ -846,9 +840,10 @@ struct UniversalGemmBasePolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
@@ -859,9 +854,10 @@ struct UniversalGemmBasePolicy
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user