From 1afa2b61e3ef30c2d3e67428422726d500ca7c1a Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 5 Aug 2025 05:22:59 -0500 Subject: [PATCH] update reference gemm mx --- .../gemm_mx_fp4_basic.cpp | 9 +- .../run_gemm_mx_example.inc | 108 +++-- .../ck_tile/host/reference/reference_gemm.hpp | 421 +++++++++++------- 3 files changed, 316 insertions(+), 222 deletions(-) diff --git a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp index 626dedea0c..9a1885e428 100644 --- a/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp +++ b/example/ck_tile/39_gemm_microscaling/gemm_mx_fp4_basic.cpp @@ -154,7 +154,10 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea #include "run_gemm_mx_example.inc" template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_mx_example_prec_type(std::string a_layout, + std::string b_layout, + int argc, + char* argv[]) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -163,7 +166,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( + return run_gemm_mx_example_with_layouts( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); } else @@ -196,7 +199,7 @@ int run_gemm_mx_example(int argc, char* argv[]) ck_tile::e8m0_bexp_t, int32_t, ck_tile::half_t>{}); - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_mx_example_prec_type(a_layout, b_layout, argc, argv); } else { diff --git a/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc b/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc index d9822e7560..030425734c 100644 --- a/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc +++ b/example/ck_tile/39_gemm_microscaling/run_gemm_mx_example.inc @@ -24,21 
+24,21 @@ template -float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& a_m_k_scale_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& b_k_n_scale_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_AQ, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - ck_tile::index_t kbatch, - int n_warmup, - int n_repeat) +float invoke_gemm_mx(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& a_m_k_scale_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& b_k_n_scale_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_AQ, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) { ck_tile::GemmMXKernelArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); @@ -102,13 +102,13 @@ template -int run_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const AScaleLayout a_scale_layout = AScaleLayout{}, - const BLayout b_layout = BLayout{}, - const BScaleLayout b_scale_layout = BScaleLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) +int run_gemm_mx_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AScaleLayout a_scale_layout = AScaleLayout{}, + const BLayout b_layout = BLayout{}, + const BScaleLayout b_scale_layout = BScaleLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) { auto [result, arg_parser] = create_args(argc, argv); if(!result) @@ -224,33 +224,33 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - a_m_k_scale_dev_buf, - b_k_n_dev_buf, - b_k_n_scale_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - Scale_Stride_A, - stride_B, 
- Scale_Stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + invoke_gemm_mx(a_m_k_dev_buf, + a_m_k_scale_dev_buf, + b_k_n_dev_buf, + b_k_n_scale_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + Scale_Stride_A, + stride_B, + Scale_Stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -261,13 +261,9 @@ int run_gemm_example_with_layouts(int argc, ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); - ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + ck_tile::reference_gemm_mx( + a_m_k, a_m_k_scale, b_k_n, b_k_n_scale, c_m_n_host_ref); + const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 70ca44170e..4aaadd786f 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -176,192 +176,287 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, } template >> -CK_TILE_HOST void -reference_gemm_multiple_d(const HostTensor& a_m_k, - const HostTensor& b_k_n, - const std::array, DsDataType::size()>& ds_m_n, - HostTensor& c_m_n, - const ACCElementOp& acc_element_op = {}) + typename AElementOp = ck_tile::identity, + typename BElementOp = ck_tile::identity, + typename ACCElementOp = ck_tile::identity> +CK_TILE_HOST void reference_gemm_mx(const HostTensor& a_m_k, + const HostTensor& a_m_k_scale, + const HostTensor& b_k_n, + const HostTensor& b_k_n_scale, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const std::size_t M = a_m_k.get_length(0); const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); - auto 
f_mk_kn_mn = [&](auto m, auto n) { - AccDataType v_acc = 0; - for(std::size_t k = 0; k < K; ++k) - { - ADataType v_a = a_m_k(m, k); - BDataType v_b = b_k_n(k, n); - v_acc += - ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); - } + const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1); - CDataType v_c = 0; - if constexpr(DsDataType::size() == 0) - { - acc_element_op(v_c, ck_tile::type_convert(v_acc)); - } - else if constexpr(DsDataType::size() == 1) - { - acc_element_op(v_c, - ck_tile::type_convert(v_acc), - ck_tile::type_convert(ds_m_n[0](m, n))); - } - else if constexpr(DsDataType::size() == 2) - { - acc_element_op(v_c, - ck_tile::type_convert(v_acc), - ck_tile::type_convert(ds_m_n[0](m, n)), - ck_tile::type_convert(ds_m_n[1](m, n))); - } - c_m_n(m, n) = ck_tile::type_convert(v_c); - }; + HostTensor a_m_k_scaled({M, K}, {K, 1}); + HostTensor b_k_n_scaled({K, N}, {1, N}); - make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); -} - -template -__global__ void naive_gemm_kernel(ADataType* A, - BDataType* B, - CDataType* C, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t strideA, - ck_tile::index_t strideB, - ck_tile::index_t strideC) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int row = idx / N; // Compute row index - int col = idx % N; // Compute column index - - if(row < M && col < N) + for(int m = 0; m < M; m++) { - AccDataType acc = 0.0; - for(int k = 0; k < K; ++k) + for(int k = 0; k < K; k++) { - constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; - constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; - // Adjust indexing based on matrix layout - int a_index = (std::is_same_v) - ? row * strideA + k - : k * strideA + row; - int b_index = (std::is_same_v) - ? 
col * strideB + k - : k * strideB + col; + if constexpr(std::is_same_v) + { + if(k % 2 == 1) + continue; // skip odd k - AccDataType v_a; - AccDataType v_b; - if constexpr(std::is_same_v) - { - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); - if(k % 2 == 1) - v_a = fp32_val.hi; - else - v_a = fp32_val.lo; + auto a_f4x2 = a_m_k(m, k); + auto a_scale = a_m_k_scale(m, k / ScaleBlockSize); + // auto f4_lo = ck_tile::type_convert(f4x2)[0]; + // auto f4_hi = ck_tile::type_convert(f4x2)[1]; + auto a_f4_lo = + ck_tile::type_convert(a_f4x2.template unpack<>(Number<0>{})); + auto a_f4_hi = + ck_tile::type_convert(a_f4x2.template unpack<>(Number<1>{})); + + a_m_k_scaled(m, k) = a_f4_lo * a_scale; + a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; } else { - v_a = ck_tile::type_convert(A[a_index]); + a_m_k_scaled(m, k) = + ck_tile::type_convert((a_m_k(m, k))) * + ck_tile::type_convert(a_m_k_scale(m, k / ScaleBlockSize)); } - if constexpr(std::is_same_v) - { - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); - if(k % 2 == 1) - v_b = fp32_val.hi; - else - v_b = fp32_val.lo; - } - else - { - v_b = ck_tile::type_convert(B[b_index]); - } - acc += v_a * v_b; } - int c_index = (std::is_same_v) - ? 
row * strideC + col - : col * strideC + row; - C[c_index] = ck_tile::type_convert(acc); + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + if constexpr(std::is_same_v) + { + if(k % 2 == 1) + continue; // skip odd k + + auto b_f4x2 = b_k_n(k, n); + auto b_scale = b_k_n_scale(k / ScaleBlockSize, n); + // auto f4_lo = ck_tile::type_convert(f4x2)[0]; + // auto f4_hi = ck_tile::type_convert(f4x2)[1]; + auto b_f4_lo = + ck_tile::type_convert(b_f4x2.template unpack<>(Number<0>{})); + auto b_f4_hi = + ck_tile::type_convert(b_f4x2.template unpack<>(Number<1>{})); + + b_k_n_scaled(k, n) = b_f4_lo * b_scale; + b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; + } + else + { + b_k_n_scaled(k, n) = + ck_tile::type_convert((b_k_n(k, n))) * + ck_tile::type_convert(b_k_n_scale(k / ScaleBlockSize, n)); + } + } + } + + // call reference_gemm + reference_gemm( + a_m_k_scaled, b_k_n_scaled, c_m_n); } -} -template -void reference_gemm_gpu(ADataType* a_ptr, - BDataType* b_ptr, - CDataType* c_ptr, - index_t M, - index_t N, - index_t K, - index_t stride_a, - index_t stride_b, - index_t stride_c) -{ - int totalElements = M * N; - int numThreadsPerBlock = 256; // Common choice for threads per block - int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; - - naive_gemm_kernel - <<>>( - a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c); - - return; -} - -template -void reference_batched_gemm_gpu(ADataType* a_ptr, - BDataType* b_ptr, - CDataType* c_ptr, - index_t M, - index_t N, - index_t K, - index_t stride_a, - index_t stride_b, - index_t stride_c, - index_t batch_stride_A, - index_t batch_stride_B, - index_t batch_stride_C, - index_t batch_count) -{ - int totalElements = M * N; - int numThreadsPerBlock = 256; // Common choice for threads per block - int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; - - for(index_t batch_id = 0; batch_id < batch_count; ++batch_id) + template >> + CK_TILE_HOST void 
reference_gemm_multiple_d( + const HostTensor& a_m_k, + const HostTensor& b_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& c_m_n, + const ACCElementOp& acc_element_op = {}) { - ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A; - BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B; - CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C; + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += ck_tile::type_convert(v_a) * + ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + if constexpr(DsDataType::size() == 0) + { + acc_element_op(v_c, ck_tile::type_convert(v_acc)); + } + else if constexpr(DsDataType::size() == 1) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n))); + } + else if constexpr(DsDataType::size() == 2) + { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(ds_m_n[0](m, n)), + ck_tile::type_convert(ds_m_n[1](m, n))); + } + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); + } + + template + __global__ void naive_gemm_kernel(ADataType * A, + BDataType * B, + CDataType * C, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t strideA, + ck_tile::index_t strideB, + ck_tile::index_t strideC) + { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int row = idx / N; // Compute row index + int col = idx % N; // Compute column index + + if(row < M && col < N) + { + AccDataType acc = 0.0; + for(int k = 0; k < K; ++k) + { + constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; + constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; + // Adjust 
indexing based on matrix layout + int a_index = (std::is_same_v) + ? row * strideA + k + : k * strideA + row; + int b_index = (std::is_same_v) + ? col * strideB + k + : k * strideB + col; + + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(A[a_index]); + } + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(B[b_index]); + } + acc += v_a * v_b; + } + + int c_index = (std::is_same_v) + ? row * strideC + col + : col * strideC + row; + C[c_index] = ck_tile::type_convert(acc); + } + } + + template + void reference_gemm_gpu(ADataType * a_ptr, + BDataType * b_ptr, + CDataType * c_ptr, + index_t M, + index_t N, + index_t K, + index_t stride_a, + index_t stride_b, + index_t stride_c) + { + int totalElements = M * N; + int numThreadsPerBlock = 256; // Common choice for threads per block + int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; + naive_gemm_kernel <<>>( - d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c); + a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c); + + return; } - return; -} + template + void reference_batched_gemm_gpu(ADataType * a_ptr, + BDataType * b_ptr, + CDataType * c_ptr, + index_t M, + index_t N, + index_t K, + index_t stride_a, + index_t stride_b, + index_t stride_c, + index_t batch_stride_A, + index_t batch_stride_B, + index_t batch_stride_C, + index_t batch_count) + { + int totalElements = M * N; + int numThreadsPerBlock = 256; // Common choice for threads per block + int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; + + for(index_t batch_id = 0; batch_id < batch_count; ++batch_id) + 
{ + ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A; + BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B; + CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C; + naive_gemm_kernel<<>>( + d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c); + } + + return; + } } // namespace ck_tile