update reference gemm mx

This commit is contained in:
mtgu0705
2025-08-05 05:22:59 -05:00
parent 9c40290dae
commit 1afa2b61e3
3 changed files with 316 additions and 222 deletions

View File

@@ -154,7 +154,10 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
#include "run_gemm_mx_example.inc"
template <typename TypeConfig, uint32_t BlockScaleSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
int run_gemm_mx_example_prec_type(std::string a_layout,
std::string b_layout,
int argc,
char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
@@ -163,7 +166,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<TypeConfig, BlockScaleSize>(
return run_gemm_mx_example_with_layouts<TypeConfig, BlockScaleSize>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else
@@ -196,7 +199,7 @@ int run_gemm_mx_example(int argc, char* argv[])
ck_tile::e8m0_bexp_t,
int32_t,
ck_tile::half_t>{});
return run_gemm_example_prec_type<TypeConfig, 32>(a_layout, b_layout, argc, argv);
return run_gemm_mx_example_prec_type<TypeConfig, 32>(a_layout, b_layout, argc, argv);
}
else
{

View File

@@ -24,21 +24,21 @@ template <typename ADataType,
typename BScaleLayout,
typename CLayout,
uint32_t BlockScaleSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& a_m_k_scale_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& b_k_n_scale_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
float invoke_gemm_mx(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& a_m_k_scale_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& b_k_n_scale_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::GemmMXKernelArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
@@ -102,13 +102,13 @@ template <typename TypeConfig,
typename BLayout,
typename BScaleLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AScaleLayout a_scale_layout = AScaleLayout{},
const BLayout b_layout = BLayout{},
const BScaleLayout b_scale_layout = BScaleLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
int run_gemm_mx_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AScaleLayout a_scale_layout = AScaleLayout{},
const BLayout b_layout = BLayout{},
const BScaleLayout b_scale_layout = BScaleLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
@@ -224,33 +224,33 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
XPackedDataType,
BDataType,
XPackedDataType,
AccDataType,
CDataType,
ALayout,
AScaleLayout,
BLayout,
BScaleLayout,
CLayout,
BlockScaleSize>(a_m_k_dev_buf,
a_m_k_scale_dev_buf,
b_k_n_dev_buf,
b_k_n_scale_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
Scale_Stride_A,
stride_B,
Scale_Stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_gemm_mx<ADataType,
XPackedDataType,
BDataType,
XPackedDataType,
AccDataType,
CDataType,
ALayout,
AScaleLayout,
BLayout,
BScaleLayout,
CLayout,
BlockScaleSize>(a_m_k_dev_buf,
a_m_k_scale_dev_buf,
b_k_n_dev_buf,
b_k_n_scale_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
Scale_Stride_A,
stride_B,
Scale_Stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
@@ -261,13 +261,9 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
ck_tile::reference_gemm_mx<ADataType, BDataType, AScaleDataType, AccDataType, CDataType>(
a_m_k, a_m_k_scale, b_k_n, b_k_n_scale, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(

View File

@@ -176,192 +176,287 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
}
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename BDatatype,
typename ScaleDataType,
typename AccDataType,
typename CDataType,
typename ACCElementOp,
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
HostTensor<CDataType>& c_m_n,
const ACCElementOp& acc_element_op = {})
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_mx(const HostTensor<ADataType>& a_m_k,
const HostTensor<ScaleDataType>& a_m_k_scale,
const HostTensor<BDatatype>& b_k_n,
const HostTensor<ScaleDataType>& b_k_n_scale,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mk_kn_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_k_n(k, n);
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
}
const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1);
CDataType v_c = 0;
if constexpr(DsDataType::size() == 0)
{
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
}
else if constexpr(DsDataType::size() == 1)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
}
else if constexpr(DsDataType::size() == 2)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
};
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
for(int m = 0; m < M; m++)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
for(int k = 0; k < K; k++)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
if constexpr(std::is_same_v<ADataType, f4x2_pk_t>)
{
if(k % 2 == 1)
continue; // skip odd k
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
auto a_f4x2 = a_m_k(m, k);
auto a_scale = a_m_k_scale(m, k / ScaleBlockSize);
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
aut a_f4_lo =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<0>{}));
auto a_f4_hi =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<1>{}));
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
a_m_k_scaled(m, k) =
ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
ck_tile::type_convert<AccDataType>(a_m_k_scale(m, k / ScaleBlockSize));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
if constexpr(std::is_same_v<BDatatype, f4x2_pk_t>)
{
if(k % 2 == 1)
continue; // skip odd k
auto b_f4x2 = b_k_n(k, n);
auto b_scale = b_k_n_scale(k / ScaleBlockSize, n);
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
auto b_f4_lo =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<0>{}));
auto b_f4_hi =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<1>{}));
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
}
else
{
b_k_n_scaled(k, n) =
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
ck_tile::type_convert<AccDataType>(b_k_n_scale(k / ScaleBlockSize, n));
}
}
}
// call reference_gemm
reference_gemm<ADataType, AccDataType, BDatatype, CDataType>(
a_m_k_scaled, b_k_n_scaled, c_m_n);
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ACCElementOp,
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void reference_gemm_multiple_d(
const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
HostTensor<CDataType>& c_m_n,
const ACCElementOp& acc_element_op = {})
{
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mk_kn_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_k_n(k, n);
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0;
if constexpr(DsDataType::size() == 0)
{
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
}
else if constexpr(DsDataType::size() == 1)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
}
else if constexpr(DsDataType::size() == 2)
{
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType * A,
BDataType * B,
CDataType * C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0;
for(int k = 0; k < K; ++k)
{
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc += v_a * v_b;
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(ADataType * a_ptr,
BDataType * b_ptr,
CDataType * c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
return;
}
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(ADataType * a_ptr,
BDataType * b_ptr,
CDataType * c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType,
BDataType,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC><<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
return;
}
} // namespace ck_tile