This commit is contained in:
Sami Remes
2025-12-11 16:21:37 +00:00
parent b2925ee207
commit 907d070ad6

View File

@@ -42,6 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
if constexpr(!std::is_same_v<GridwiseGemm::UnquantizedADatatype, GridwiseGemm::ADataType>)
{
// Quantize A matrix: populate p_a_grid with quantized data from p_a_grid_unquantized
// and compute scales into p_a_scale_grid
GridwiseGemm::QuantizeA(karg.p_a_grid_unquantized, karg.p_a_grid, karg.p_a_scale_grid, karg.M, karg.K);
__syncthreads();
}
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
@@ -115,7 +123,8 @@ template <typename ALayout,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
typename LDSTypeB = BDataType,
typename UnquantizedADatatype = ADataType>
struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
{
using AScaleType = float;
@@ -627,7 +636,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
CElementwiseOperation c_element_op_,
UnquantizedADatatype* p_a_grid_unquantized_ = nullptr)
: Problem{M_,
N_,
K_,
@@ -646,7 +656,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_}
c_element_op{c_element_op_},
p_a_grid_unquantized{p_a_grid_unquantized_}
{
// populate pointer, desc for Ds
@@ -669,6 +680,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
// Quantize A before computing the GEMM if this is not a nullptr.
// In this case p_a_grid_ points to the empty/uninitialized data buffer for quantized A.
// And p_a_grid_unquantized points to the original unquantized A data.
const UnquantizedADatatype* p_a_grid_unquantized;
};
struct SplitKBatchOffset
@@ -1753,6 +1769,118 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
});
}
}
/**
* This code is adapted from AITER: aiter/csrc/kernels/quant_kernels.cu
*/
__device__ static void QuantizeA(
const UnquantizedADatatype* p_a_grid_unquantized,
const ADatatype* p_a_grid_const,
const AScaleType* p_a_scale_grid_const,
index_t M,
index_t K)
{
// HACK: Remove constness only here, to avoid changing other interfaces
ADatatype* p_a_grid = const_cast<ADatatype*>(p_a_grid_const);
AScaleDataType* p_a_scale_grid = const_cast<AScaleDataType*>(p_a_scale_grid_const);
static_assert(std::is_same_v<ADatatype, ck::fp8_t> ||
std::is_same_v<ADatatype, ck::bfloat8_t>,
"only fp8 and bfloat8 are supported!");
static_assert(ScaleBlockM == 1, "only per-token quantization is supported!");
// set variables used by the original aiter code from GridwiseGemm template params
constexpr index_t thread_data_size = 32;
constexpr index_t groupQuantBlockSize = 64;
constexpr index_t group_size = ScaleBlockK;
index_t ori_rows = M;
index_t ori_cols = K;
int num_thread_per_group = group_size / thread_data_size;
int64_t row_offset = blockIdx.x * groupQuantBlockSize;
int64_t groupId = (row_offset + threadIdx.x) / num_thread_per_group;
int32_t scaleN_pad = ori_cols / group_size;
int64_t x = groupId / scaleN_pad;
int32_t y = groupId % scaleN_pad;
if(x >= ori_rows)
return;
row_offset = x * ori_row_stride + y * group_size;
using vec_i = ck_tile::vec_t<DTYPE_I, thread_data_size>;
static constexpr int32_t vec_size_o = thread_data_size;
using vec_o = ck_tile::vec_t<DTYPE_O, vec_size_o>;
const float inverted_DTYPE_MAX =
std::is_same_v<DTYPE_O, ck_tile::fp4x2_t>
? 0.25
: (1. / ck::type_convert<float>(ck_tile::numeric<DTYPE_O>::max()));
static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
const int64_t oob_o = (ori_rows * ori_cols + ooba_o - 1) / ooba_o * ooba_o;
auto const* input_vecs = reinterpret_cast<vec_i const*>(input + row_offset);
vec_i thread_data = input_vecs[threadIdx.x % num_thread_per_group];
float absMax = 1e-10f;
for(size_t j = 0; j < thread_data_size; j++)
{
absMax = max(absMax, abs(ck::type_convert<float>(thread_data[j])));
}
absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group);
float inverted_scale = absMax * inverted_DTYPE_MAX;
row_offset = groupId * group_size + (threadIdx.x % num_thread_per_group) * vec_size_o;
if(threadIdx.x % num_thread_per_group == 0)
{
if constexpr(std::is_same_v<DTYPE_O, ck_tile::fp4x2_t>)
{
auto* tmp = reinterpret_cast<uint8_t*>(scale);
uint8_t exponent = (ck_tile::bit_cast<uint32_t>(inverted_scale) >> 23) & 0b11111111;
if(shuffle_scale)
{
groupId = fp4_scale_shuffle_id(scaleN_pad, x, y);
}
tmp[groupId] = exponent;
}
else
{
if(shuffle_scale)
{
groupId = y * ori_rows + x;
}
scale[groupId] = inverted_scale;
}
}
inverted_scale =
std::is_same_v<DTYPE_O, ck_tile::fp4x2_t> ? inverted_scale : 1.0f / inverted_scale;
using DTYPE_STORE = typename ck_tile::vector_traits<DTYPE_O>::scalar_type;
auto* out_ptr = reinterpret_cast<DTYPE_STORE*>(out);
auto buffer_o =
ck_tile::make_buffer_view<ck_tile::address_space_enum::global,
ck_tile::amd_buffer_coherence_enum::glc>(out_ptr, oob_o);
buffer_o.init_raw();
auto out_s =
ck_tile::vec_convert<DTYPE_O, DTYPE_I, thread_data_size>(thread_data, inverted_scale)
.template get_as<DTYPE_STORE>();
if constexpr(thread_data_size <= 16)
{
buffer_o.template set(row_offset, 0, true, out_s);
}
else
{
static constexpr int32_t o_step = std::is_same_v<DTYPE_O, ck_tile::fp4x2_t> ? 8 : 16;
assert(thread_data_size % 16 == 0);
using vecT = ck_tile::vec_t<DTYPE_STORE, o_step>;
auto vec = out_s.template get_as<vecT>();
static constexpr int32_t num_iter = thread_data_size / 16;
for(size_t j = 0; j < num_iter; j++)
{
buffer_o.template set(row_offset + j * o_step, 0, true, vec[j]);
}
}
}
};
} // namespace ck