diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 35b6025594..b892a22777 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -249,6 +249,106 @@ struct BlockGemmARegBRegCRegV1 }); } + // C += A * B with MX scaling + // ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t + // ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor, + const ScaleATensor& scale_a_tensor, + const ScaleBTensor& scale_b_tensor) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + // check ABC-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "A distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "B distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop with MX scaling: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // get A scale for this M-K tile + const int32_t a_scale = scale_a_tensor(mIter, kIter); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // get B scale for this N-K tile + const int32_t b_scale = scale_b_tensor(nIter, kIter); + + // read C warp tensor from C block tensor + using c_iter_idx = std:: + conditional_t, sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM with MX scaling + // opsel is kIter for both A and B (selecting which packed element group) + WarpGemm{}.template operator()( + c_warp_tensor, a_warp_tensor, a_scale, b_warp_tensor, b_scale); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { using c_distr_ys_major = std::conditional_t, sequence<1, 2>>;