mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
WIP
This commit is contained in:
@@ -42,6 +42,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
if constexpr(!std::is_same_v<GridwiseGemm::UnquantizedADatatype, GridwiseGemm::ADataType>)
|
||||
{
|
||||
// Quantize A matrix: populate p_a_grid with quantized data from p_a_grid_unquantized
|
||||
// and compute scales into p_a_scale_grid
|
||||
GridwiseGemm::QuantizeA(karg.p_a_grid_unquantized, karg.p_a_grid, karg.p_a_scale_grid, karg.M, karg.K);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
@@ -115,7 +123,8 @@ template <typename ALayout,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
typename LDSTypeB = BDataType>
|
||||
typename LDSTypeB = BDataType,
|
||||
typename UnquantizedADatatype = ADataType>
|
||||
struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
{
|
||||
using AScaleType = float;
|
||||
@@ -627,7 +636,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
CElementwiseOperation c_element_op_,
|
||||
UnquantizedADatatype* p_a_grid_unquantized_ = nullptr)
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
@@ -646,7 +656,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
p_b_scale_grid{p_b_scale_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
c_element_op{c_element_op_}
|
||||
c_element_op{c_element_op_},
|
||||
p_a_grid_unquantized{p_a_grid_unquantized_}
|
||||
{
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
@@ -669,6 +680,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CElementwiseOperation c_element_op;
|
||||
|
||||
// Quantize A before computing the GEMM if this is not a nullptr.
|
||||
// In this case p_a_grid_ points to the empty/uninitialized data buffer for quantized A.
|
||||
// And p_a_grid_unquantized points to the original unquantized A data.
|
||||
const UnquantizedADatatype* p_a_grid_unquantized;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
@@ -1753,6 +1769,118 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This code is adapted from AITER: aiter/csrc/kernels/quant_kernels.cu
|
||||
*/
|
||||
__device__ static void QuantizeA(
|
||||
const UnquantizedADatatype* p_a_grid_unquantized,
|
||||
const ADatatype* p_a_grid_const,
|
||||
const AScaleType* p_a_scale_grid_const,
|
||||
index_t M,
|
||||
index_t K)
|
||||
{
|
||||
// HACK: Remove constness only here, to avoid changing other interfaces
|
||||
ADatatype* p_a_grid = const_cast<ADatatype*>(p_a_grid_const);
|
||||
AScaleDataType* p_a_scale_grid = const_cast<AScaleDataType*>(p_a_scale_grid_const);
|
||||
|
||||
static_assert(std::is_same_v<ADatatype, ck::fp8_t> ||
|
||||
std::is_same_v<ADatatype, ck::bfloat8_t>,
|
||||
"only fp8 and bfloat8 are supported!");
|
||||
static_assert(ScaleBlockM == 1, "only per-token quantization is supported!");
|
||||
|
||||
// set variables used by the original aiter code from GridwiseGemm template params
|
||||
constexpr index_t thread_data_size = 32;
|
||||
constexpr index_t groupQuantBlockSize = 64;
|
||||
constexpr index_t group_size = ScaleBlockK;
|
||||
index_t ori_rows = M;
|
||||
index_t ori_cols = K;
|
||||
|
||||
int num_thread_per_group = group_size / thread_data_size;
|
||||
int64_t row_offset = blockIdx.x * groupQuantBlockSize;
|
||||
int64_t groupId = (row_offset + threadIdx.x) / num_thread_per_group;
|
||||
int32_t scaleN_pad = ori_cols / group_size;
|
||||
int64_t x = groupId / scaleN_pad;
|
||||
int32_t y = groupId % scaleN_pad;
|
||||
|
||||
if(x >= ori_rows)
|
||||
return;
|
||||
|
||||
row_offset = x * ori_row_stride + y * group_size;
|
||||
using vec_i = ck_tile::vec_t<DTYPE_I, thread_data_size>;
|
||||
static constexpr int32_t vec_size_o = thread_data_size;
|
||||
using vec_o = ck_tile::vec_t<DTYPE_O, vec_size_o>;
|
||||
const float inverted_DTYPE_MAX =
|
||||
std::is_same_v<DTYPE_O, ck_tile::fp4x2_t>
|
||||
? 0.25
|
||||
: (1. / ck::type_convert<float>(ck_tile::numeric<DTYPE_O>::max()));
|
||||
|
||||
static constexpr int32_t ooba_o = 4 / sizeof(DTYPE_O);
|
||||
const int64_t oob_o = (ori_rows * ori_cols + ooba_o - 1) / ooba_o * ooba_o;
|
||||
|
||||
auto const* input_vecs = reinterpret_cast<vec_i const*>(input + row_offset);
|
||||
vec_i thread_data = input_vecs[threadIdx.x % num_thread_per_group];
|
||||
float absMax = 1e-10f;
|
||||
for(size_t j = 0; j < thread_data_size; j++)
|
||||
{
|
||||
absMax = max(absMax, abs(ck::type_convert<float>(thread_data[j])));
|
||||
}
|
||||
absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group);
|
||||
|
||||
float inverted_scale = absMax * inverted_DTYPE_MAX;
|
||||
row_offset = groupId * group_size + (threadIdx.x % num_thread_per_group) * vec_size_o;
|
||||
if(threadIdx.x % num_thread_per_group == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<DTYPE_O, ck_tile::fp4x2_t>)
|
||||
{
|
||||
auto* tmp = reinterpret_cast<uint8_t*>(scale);
|
||||
uint8_t exponent = (ck_tile::bit_cast<uint32_t>(inverted_scale) >> 23) & 0b11111111;
|
||||
if(shuffle_scale)
|
||||
{
|
||||
groupId = fp4_scale_shuffle_id(scaleN_pad, x, y);
|
||||
}
|
||||
tmp[groupId] = exponent;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(shuffle_scale)
|
||||
{
|
||||
groupId = y * ori_rows + x;
|
||||
}
|
||||
scale[groupId] = inverted_scale;
|
||||
}
|
||||
}
|
||||
inverted_scale =
|
||||
std::is_same_v<DTYPE_O, ck_tile::fp4x2_t> ? inverted_scale : 1.0f / inverted_scale;
|
||||
|
||||
using DTYPE_STORE = typename ck_tile::vector_traits<DTYPE_O>::scalar_type;
|
||||
auto* out_ptr = reinterpret_cast<DTYPE_STORE*>(out);
|
||||
auto buffer_o =
|
||||
ck_tile::make_buffer_view<ck_tile::address_space_enum::global,
|
||||
ck_tile::amd_buffer_coherence_enum::glc>(out_ptr, oob_o);
|
||||
buffer_o.init_raw();
|
||||
|
||||
auto out_s =
|
||||
ck_tile::vec_convert<DTYPE_O, DTYPE_I, thread_data_size>(thread_data, inverted_scale)
|
||||
.template get_as<DTYPE_STORE>();
|
||||
if constexpr(thread_data_size <= 16)
|
||||
{
|
||||
buffer_o.template set(row_offset, 0, true, out_s);
|
||||
}
|
||||
else
|
||||
{
|
||||
static constexpr int32_t o_step = std::is_same_v<DTYPE_O, ck_tile::fp4x2_t> ? 8 : 16;
|
||||
assert(thread_data_size % 16 == 0);
|
||||
using vecT = ck_tile::vec_t<DTYPE_STORE, o_step>;
|
||||
auto vec = out_s.template get_as<vecT>();
|
||||
static constexpr int32_t num_iter = thread_data_size / 16;
|
||||
|
||||
for(size_t j = 0; j < num_iter; j++)
|
||||
{
|
||||
buffer_o.template set(row_offset + j * o_step, 0, true, vec[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user