Sami Remes
2026-01-14 12:07:26 -05:00
parent 5d4e07e095
commit f6f9931541
5 changed files with 181 additions and 122 deletions

View File

@@ -180,92 +180,115 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
// Create C block window following UniversalGemmKernel pattern
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr,
const KernelArgs<ScaleM, ScaleN>& kargs,
const index_t i_m,
const index_t i_n)
{
// Create tensor view for E/C tensor
constexpr index_t vector_size = EpiloguePipeline::GetVectorSizeC();
const auto& e_tensor_view = [&]() -> auto {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<vector_size>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
number<vector_size>{});
}
}();
// Create padded view (identical in both layouts: no M/N padding here)
const auto& e_pad_view = pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, false>{});
// Create block window
auto c_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return c_block_window;
}
// Create scale A block windows following the pattern of MakeABlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_m)
{
auto scale_a = kargs.scale_m_ptr;
static_assert(ScaleM::GranularityK == ScaleN::GranularityK, "M and N scales must have same K granularity!");
static constexpr int BlockScaleSize = ScaleM::GranularityK;
// With 1D K-only packing: each int32 contains KXdlPack consecutive e8m0 values
// Scale A layout: [M, K/BlockScaleSize/KXdlPack] where each element is int32
// Scale B layout: [K/BlockScaleSize/KXdlPack, N] where each element is int32 (see MakeScaleBBlockWindows)
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
// A scale tensor view - simple 2D layout [M, K/BlockScaleSize/KXdlPack]
const auto scale_a_desc = make_naive_tensor_descriptor_packed(
make_tuple(kargs.M, scale_k_packed));
const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
// Create block window for scale A
auto scale_a_block_window = make_tile_window(
scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
{block_idx_m, 0});
return scale_a_block_window;
}
// Create scale B block windows following the pattern of MakeBBlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_n)
{
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = ScaleN::GranularityK;
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
// B scale tensor view - layout [K/BlockScaleSize/KXdlPack, N]
const auto scale_b_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_k_packed, kargs.N));
const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
// Create block window for scale B
auto scale_b_block_window = make_tile_window(
scale_b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
return scale_b_block_window;
}
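For context, here is a minimal host-side sketch of the packing these two windows assume. The helper pack_scale_a and its low-to-high byte order are illustrative assumptions, not part of this commit; the real byte order must match the pipeline's OpSel handling.
// Hypothetical host-side packing sketch: KXdlPack = 4 consecutive e8m0 bytes
// along K go into one int32, giving scale A a [M, K/32/4] int32 layout
// (scale B is analogous with layout [K/32/4, N]).
#include <cstdint>
#include <vector>
std::vector<int32_t> pack_scale_a(const std::vector<uint8_t>& e8m0, int M, int K)
{
    const int scale_k        = K / 32;      // one e8m0 scale per 32-wide K block
    const int scale_k_packed = scale_k / 4; // KXdlPack = 4 scales per int32
    std::vector<int32_t> packed(static_cast<size_t>(M) * scale_k_packed);
    for(int m = 0; m < M; ++m)
        for(int kp = 0; kp < scale_k_packed; ++kp)
        {
            uint32_t v = 0;
            for(int i = 0; i < 4; ++i) // assumed low-to-high byte order
                v |= uint32_t(e8m0[m * scale_k + kp * 4 + i]) << (8 * i);
            packed[m * scale_k_packed + kp] = static_cast<int32_t>(v);
        }
    return packed;
}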
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
@@ -281,22 +304,19 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create block windows directly, following the new pattern from UniversalGemmKernel
const auto& a_block_window =
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& b_block_window =
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
// Create scale block windows using our new functions
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, block_idx_m);
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by the whole workgroup.
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same GranularityK
|| ScaleM::GranularityMN == -1 // or ScaleA is disabled
|| ScaleN::GranularityMN == -1, // or ScaleB is disabled
@@ -310,8 +330,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline - create C block window directly
auto c_block_window =
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, block_idx_m, block_idx_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
@@ -338,7 +359,8 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
// Cast to base class for SplitKBatchOffset construction
const SplitKBatchOffset splitk_batch_offset(static_cast<const typename Underlying::KernelArgs&>(kargs));
// options
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);

View File

@@ -5,19 +5,22 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// MX scaling support with OpSel
template <typename Problem>
struct BaseMXGemmPipelineAgBgCrCompAsync
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -87,15 +90,16 @@ struct BaseGemmPipelineAgBgCrCompAsync
};
/**
* @brief MX GEMM compute-optimized asynchronous pipeline, based on V4.
*
* This pipeline introduces asynchronous load from global memory to LDS,
* skipping the intermediate loading into pipeline registers.
* Supports MX scaling with e8m0 packed values and OpSel.
*/
template <typename Problem, typename Policy = MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<Problem>
{
using Base = BaseMXGemmPipelineAgBgCrCompAsync<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
@@ -117,6 +121,11 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
// MX scaling packing constants
static constexpr int MXdlPack = Policy::MXdlPack;
static constexpr int NXdlPack = Policy::NXdlPack;
static constexpr int KXdlPack = Policy::KXdlPack;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
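To make these packing constants concrete: with BlockScaleSize = 32 and KXdlPack = 4, the scale covering element k of a row sits in int32 word k / (32 * 4), byte (k / 32) % 4, and an e8m0 value decodes to a pure power of two (biased exponent, bias 127, 0xFF reserved for NaN). A minimal sketch with hypothetical helper names, assuming low-to-high byte order within each int32:
#include <cmath>
#include <cstdint>
#include <limits>
// Decode one e8m0 scale: 2^(biased_exp - 127), with 0xFF mapping to NaN.
inline float decode_e8m0(uint8_t biased_exp)
{
    if(biased_exp == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, int(biased_exp) - 127);
}
// Fetch the scale covering element k from one packed row of int32 values.
inline float scale_at(const int32_t* packed_row, int k)
{
    const int block = k / 32;    // which 32-wide K block
    const int word  = block / 4; // KXdlPack = 4 scales per int32
    const int byte  = block % 4; // position within the int32
    return decode_e8m0(uint8_t(uint32_t(packed_row[word]) >> (8 * byte)));
}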
@@ -317,7 +326,6 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
constexpr index_t KPerXdl = WarpTile::at(I2{});
constexpr index_t ScaleBlockSize = 32;
// Scale A DRAM Window: [MWarp * MPerXdl, kKPerBlock / 32 / KXdlPack]
auto scale_a_dram_window = make_tile_window(
@@ -520,19 +528,28 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
constexpr auto OpSelB = kScaleInPack;
// Extract A/B values for this iteration - create warp tensors
typename WarpGemm::AWarpTensor a_warp_tensor{};
const auto a_thread_data = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, a_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
a_warp_tensor.get_thread_buffer()(i) = a_thread_data[i];
});
typename WarpGemm::BWarpTensor b_warp_tensor{};
const auto b_thread_data = b_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
static_for<0, b_warp_tensor.get_thread_buffer_size(), 1>{}([&](auto i) {
b_warp_tensor.get_thread_buffer()(i) = b_thread_data[i];
});
WarpGemm{}.template operator()<OpSelA, OpSelB>(
c_warp_tensors(m_iter)(n_iter),
a_warp_tensor,
b_warp_tensor,
scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
});
});
@@ -742,9 +759,9 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
a_dram_block_window_tmp,
element_wise::PassThrough{},
b_dram_block_window_tmp,
element_wise::PassThrough{},
scale_a_window,
scale_b_window,
num_loop,

View File

@@ -9,11 +9,12 @@
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
namespace ck_tile {
// Default policy for MXGemmPipelineAgBgCrCompAsync
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
// Adds MX scale tile distributions
struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
: public UniversalGemmBasePolicy<MXGemmPipelineAgBgCrCompAsyncDefaultPolicy>
{
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
@@ -134,7 +135,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// MX Scale tile distributions for loading from global memory
// Using the proven "Flat" patterns from v1 policy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
@@ -147,16 +149,17 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 4 for 16x16 mfma
// Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile
// Distribution: simple 2D for loading int32 packed scales
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
tuple<sequence<MWarp, MPerXdl>, // M dimension
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
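As a sanity check on the geometry above (illustrative numbers consistent with the in-code comments, not a specific config from this commit):
// A 16x16 mfma on a 64-lane wave leaves 64 / 16 = 4 lanes along K, matching
// the "4 for 16x16 mfma" comment; each K-lane loads one int32, i.e.
// KXdlPack = 4 packed e8m0 scales, covering 4 * 32 = 128 K elements at
// BlockScaleSize = 32.
constexpr int warp_size = 64;
constexpr int MPerXdl   = 16;
constexpr int K_Lane    = warp_size / MPerXdl;
static_assert(K_Lane == 4, "16x16 mfma on a 64-lane wave yields 4 K-lanes");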
template <typename Problem>
@@ -170,17 +173,22 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
static_assert(K_Lane == 4, "K_Lane must be 4 for 16x16 mfma");
static_assert(NPerXdl == 16, "NPerXdl must be 16 for 16x16 mfma");
static_assert(MWarp == 1, "MWarp must be 1 for 16x16 mfma");
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
// Layout is [K, N] where K is packed int32
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
sequence<NWarp, NPerXdl>>, // N dimension
tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
tuple<sequence<0, 1>, sequence<0, 0>>, // which index
// <repeat, vec_load>
sequence<1>,
sequence<1>>{});
}
};
} // namespace ck_tile