mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -249,6 +249,113 @@ struct BlockGemmARegBRegCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B with MX scaling
|
||||
// ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t
|
||||
// ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t
|
||||
template <typename CBlockTensor,
|
||||
typename ABlockTensor,
|
||||
typename BBlockTensor,
|
||||
typename ScaleATensor,
|
||||
typename ScaleBTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor,
|
||||
const ScaleATensor& scale_a_tensor,
|
||||
const ScaleBTensor& scale_b_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop with MX scaling:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// get A scale for this M-K tile using get_y_sliced_thread_data
|
||||
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
|
||||
sequence<kIter, mIter, 0>{}, sequence<1, 1, 1>{});
|
||||
const auto a_scale_e8m0 = scale_a_slice[number<0>{}];
|
||||
const int32_t a_scale = static_cast<int32_t>(a_scale_e8m0.get());
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// get B scale for this N-K tile using get_y_sliced_thread_data
|
||||
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
|
||||
sequence<kIter, nIter, 0>{}, sequence<1, 1, 1>{});
|
||||
const auto b_scale_e8m0 = scale_b_slice[number<0>{}];
|
||||
const int32_t b_scale = static_cast<int32_t>(b_scale_e8m0.get());
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::
|
||||
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM with MX scaling
|
||||
// Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
|
||||
constexpr index_t kOpSel = 0; // Always use OpSel=0
|
||||
WarpGemm{}.template operator()<kOpSel, kOpSel>(
|
||||
c_warp_tensor, a_warp_tensor, b_warp_tensor, a_scale, b_scale);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
|
||||
|
||||
Reference in New Issue
Block a user