diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 622eb74f67..6dc41e821b 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -177,138 +177,6 @@ CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, else if constexpr(std::is_same_v) { v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); -<<<<<<< HEAD - } - else - { - v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); - } - v_block_acc += v_a * v_b; - - // Apply group dequant scale - if((k + 1) % QuantGroupSize::kK == 0) - { - float a_scale = 0.f; - float b_scale = 0.f; - // A scale - index_t outer_dim = m / QuantGroupSize::kM; - index_t inner_dim = k / QuantGroupSize::kK; - if constexpr(std::is_same_v) - { - a_scale = a_q(outer_dim, inner_dim); - } - else if constexpr(std::is_same_v) - { - a_scale = fp8_to_float_raw(a_q(outer_dim, inner_dim)); - } - else if constexpr(std::is_same_v) - { - a_scale = bf8_to_float_raw(a_q(outer_dim, inner_dim)); - } - else - { - static_assert(false, "Unexpected Q datatype."); - } - // B scale - outer_dim = k / QuantGroupSize::kK; - inner_dim = n / QuantGroupSize::kN; - if constexpr(std::is_same_v) - { - b_scale = b_q(outer_dim, inner_dim); - } - else if constexpr(std::is_same_v) - { - b_scale = fp8_to_float_raw(b_q(outer_dim, inner_dim)); - } - else if constexpr(std::is_same_v) - { - b_scale = bf8_to_float_raw(b_q(outer_dim, inner_dim)); - } - else - { - static_assert(false, "Unexpected Q datatype."); - } - v_block_acc = v_block_acc * a_scale * b_scale; - v_acc += v_block_acc; - v_block_acc = 0; - } - } - - c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); - }; - - make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); -} - -template -CK_TILE_HOST void reference_gemm_abquant(const HostTensor& a_m_k, - const HostTensor& a_q, - const HostTensor& b_k_n, - const HostTensor& b_q, - HostTensor& c_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) -{ - const std::size_t M = a_m_k.get_length(0); - const std::size_t N = b_k_n.get_length(1); - const std::size_t K = a_m_k.get_length(1); - - auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0, v_block_acc = 0; - - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v || - std::is_same_v); - for(std::size_t k = 0; k < K; ++k) - { - AccDataType v_a; - AccDataType v_b; - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_a = fp32_val.hi; - else - v_a = fp32_val.lo; - } - else - { - v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); - // printf("A %f m=%d k=%d\n", static_cast(v_a),static_cast(m) - // ,static_cast(k)); - } - - if constexpr(std::is_same_v) - { - const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); - const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); - if(k % 2 == 1) - v_b = fp32_val.hi; - else - v_b = fp32_val.lo; - } - else if constexpr(std::is_same_v) - { - v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); - // printf("B %f k=%d n=%d\n", static_cast(v_b),static_cast(k) - // ,static_cast(n)); -======= ->>>>>>> 198c21436 (Support A/B Quantization in Blockscale GEMM) } else {