From 907d070ad6632e04c8a9be58153ebce85eca383b Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Thu, 11 Dec 2025 16:21:37 +0000 Subject: [PATCH] WIP --- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 134 +++++++++++++++++- 1 file changed, 131 insertions(+), 3 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 13061c7cd1..0b337d70d1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -42,6 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + if constexpr(!std::is_same_v) + { + // Quantize A matrix: populate p_a_grid with quantized data from p_a_grid_unquantized + // and compute scales into p_a_scale_grid. NOTE(review): the __syncthreads() below only synchronizes threads of the same workgroup - confirm no other workgroup reads these outputs before they are written. + GridwiseGemm::QuantizeA(karg.p_a_grid_unquantized, karg.p_a_grid, karg.p_a_scale_grid, karg.M, karg.K); + __syncthreads(); + } auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -115,7 +123,8 @@ template + typename LDSTypeB = BDataType, + typename UnquantizedADatatype = ADataType> struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 { using AScaleType = float; @@ -627,7 +636,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_) + CElementwiseOperation c_element_op_, + UnquantizedADatatype* p_a_grid_unquantized_ = nullptr) : Problem{M_, N_, K_, @@ -646,7 +656,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 p_b_scale_grid{p_b_scale_grid_}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, - c_element_op{c_element_op_} +
c_element_op{c_element_op_}, + p_a_grid_unquantized{p_a_grid_unquantized_} { // populate pointer, desc for Ds @@ -669,6 +680,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const AElementwiseOperation a_element_op; const BElementwiseOperation b_element_op; const CElementwiseOperation c_element_op; + + // Quantize A before computing the GEMM if this is not a nullptr. NOTE(review): the kernel-side gate visible in this patch is a compile-time type comparison, not a null check - confirm a nullptr cannot reach QuantizeA. + // In this case p_a_grid_ points to the empty/uninitialized data buffer for quantized A. + // And p_a_grid_unquantized points to the original unquantized A data. + const UnquantizedADatatype* p_a_grid_unquantized; }; struct SplitKBatchOffset @@ -1753,6 +1769,118 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 }); } } + + /** + * Group-wise quantization of A, adapted from AITER: aiter/csrc/kernels/quant_kernels.cu. Each group of ScaleBlockK consecutive elements of a row is reduced to its abs-max by num_thread_per_group threads (32 elements per thread); the group's first thread stores the scale (absMax * inverted_DTYPE_MAX, or its exponent byte on the fp4 path) into p_a_scale_grid and every thread writes its converted elements into p_a_grid. NOTE(review): WIP - ori_row_stride, shuffle_scale and DTYPE_O are names from the AITER kernel that this function does not declare yet, and the groups a workgroup handles come from blockIdx.x * groupQuantBlockSize + threadIdx.x, so coverage of A depends on the launch grid - confirm both. + */ + __device__ static void QuantizeA( + const UnquantizedADatatype* p_a_grid_unquantized, + const ADataType* p_a_grid_const, + const AScaleType* p_a_scale_grid_const, + index_t M, + index_t K) + { + // HACK: Remove constness only here, to avoid changing other interfaces + ADataType* p_a_grid = const_cast(p_a_grid_const); + AScaleType* p_a_scale_grid = const_cast(p_a_scale_grid_const); + + static_assert(std::is_same_v || + std::is_same_v, + "only fp8 and bfloat8 are supported!"); + static_assert(ScaleBlockM == 1, "only per-token quantization is supported!"); + + // set variables used by the original aiter code from GridwiseGemm template params (ori_rows is 64-bit so products with it are not evaluated in a narrower index_t) + constexpr index_t thread_data_size = 32; + constexpr index_t groupQuantBlockSize = 64; + constexpr index_t group_size = ScaleBlockK; + const int64_t ori_rows = M; + index_t ori_cols = K; + + int num_thread_per_group = group_size / thread_data_size; + int64_t row_offset = blockIdx.x * groupQuantBlockSize; + int64_t groupId = (row_offset + threadIdx.x) / num_thread_per_group; + int32_t scaleN_pad = ori_cols / group_size; + int64_t x = groupId / scaleN_pad; + int32_t y = groupId % scaleN_pad; + + if(x >= ori_rows) + return; + + row_offset = x *
ori_row_stride + y * group_size; + using vec_i = ck_tile::vec_t; + static constexpr int32_t vec_size_o = thread_data_size; + using vec_o = ck_tile::vec_t; + const float inverted_DTYPE_MAX = + std::is_same_v + ? 0.25 + : (1. / ck::type_convert(ck_tile::numeric::max())); + + static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O); + const int64_t oob_o = (ori_rows * ori_cols + ooba_o - 1) / ooba_o * ooba_o; + + auto const* input_vecs = reinterpret_cast(p_a_grid_unquantized + row_offset); + vec_i thread_data = input_vecs[threadIdx.x % num_thread_per_group]; + float absMax = 1e-10f; + for(size_t j = 0; j < thread_data_size; j++) + { + absMax = max(absMax, abs(ck::type_convert(thread_data[j]))); + } + absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group); + + float inverted_scale = absMax * inverted_DTYPE_MAX; + row_offset = groupId * group_size + (threadIdx.x % num_thread_per_group) * vec_size_o; + if(threadIdx.x % num_thread_per_group == 0) + { + if constexpr(std::is_same_v) + { + auto* tmp = reinterpret_cast(p_a_scale_grid); + uint8_t exponent = (ck_tile::bit_cast(inverted_scale) >> 23) & 0b11111111; + if(shuffle_scale) + { + groupId = fp4_scale_shuffle_id(scaleN_pad, x, y); + } + tmp[groupId] = exponent; + } + else + { + if(shuffle_scale) + { + groupId = y * ori_rows + x; + } + p_a_scale_grid[groupId] = inverted_scale; + } + } + inverted_scale = + std::is_same_v ? inverted_scale : 1.0f / inverted_scale; + + using DTYPE_STORE = typename ck_tile::vector_traits::scalar_type; + auto* out_ptr = reinterpret_cast(p_a_grid); + auto buffer_o = + ck_tile::make_buffer_view(out_ptr, oob_o); + buffer_o.init_raw(); + + auto out_s = + ck_tile::vec_convert(thread_data, inverted_scale) + .template get_as(); + if constexpr(thread_data_size <= 16) + { + buffer_o.template set(row_offset, 0, true, out_s); + } + else + { + static constexpr int32_t o_step = std::is_same_v ?
8 : 16; + static_assert(thread_data_size % 16 == 0); + using vecT = ck_tile::vec_t; + auto vec = out_s.template get_as(); + static constexpr int32_t num_iter = thread_data_size / 16; + + for(size_t j = 0; j < num_iter; j++) + { + buffer_o.template set(row_offset + j * o_step, 0, true, vec[j]); + } + } + } }; } // namespace ck