mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using an async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -666,13 +666,13 @@ struct HostTensor
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
|
||||
{
|
||||
os << type_convert<float>(mData[idx]) << " #### ";
|
||||
os << type_convert<float>(mData[idx]);
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
|
||||
{
|
||||
auto unpacked = pk_int4_t_to_int8x2_t(mData[idx]);
|
||||
os << "pk(" << static_cast<int>(unpacked[0]) << ", "
|
||||
<< static_cast<int>(unpacked[1]) << ") #### ";
|
||||
<< static_cast<int>(unpacked[1]) << ")";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
@@ -471,27 +472,42 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
// HostTensor automatically handles packed indexing: a_m_k(m,k) divides offset by
|
||||
// PackedSize, so a_m_k(m,0) and a_m_k(m,1) return the same packed byte.
|
||||
const pk_fp4_t pk_val = a_m_k(m, k);
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = a_m_k(m, k);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(unpacked));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_fp4_t pk_val = b_k_n(k, n);
|
||||
const fp32x2_t fp32_val = pk_val.to_fp32x2(1.0f);
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
// HostTensor automatically handles packed indexing
|
||||
const pk_int4_t pk_val = b_k_n(k, n);
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
const float unpacked = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(unpacked));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -671,7 +687,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp6x16_t>)
|
||||
{
|
||||
if(k % pk_fp6x16_t::packed_size != 0)
|
||||
continue;
|
||||
@@ -796,7 +812,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
@@ -816,7 +832,7 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
@@ -908,7 +924,7 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user