try to enable scale loading in kernel and pipeline

This commit is contained in:
Sami Remes
2026-02-05 09:24:47 +00:00
parent 329eabd73b
commit 6c61804665
3 changed files with 191 additions and 357 deletions

View File

@@ -152,30 +152,40 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
CK_TILE_HOST static constexpr auto
GridSize(const KernelArgs<ScaleM, ScaleN>& kargs)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
hipGetErrorName(hipGetLastError()));
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
hipGetErrorName(hipGetLastError()));
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
KernelBlockSize,
dync_smem_size) != hipSuccess)
throw std::runtime_error(
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
hipGetErrorName(hipGetLastError()));
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry<1, MXGemmKernel, remove_cvref_t<decltype(kargs)>>),
KernelBlockSize,
dync_smem_size) != hipSuccess)
throw std::runtime_error(
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
hipGetErrorName(hipGetLastError()));
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, 1);
return dim3(actual_grid_size, 1, 1);
}
else
{
// Non-persistent: use full grid size based on number of tiles
return dim3(total_work_tile_cnt, 1, 1);
}
}
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
@@ -240,26 +250,36 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
// Create scale A block windows following the pattern of MakeABlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_m)
MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t i_m)
{
auto scale_a = kargs.scale_m_ptr;
static constexpr int BlockScaleSize = ScaleM::GranularityK;
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
const auto scale_k_size = kargs.K / BlockScaleSize;
const auto scale_k_size_packed = scale_k_size / KXdlPack;
// A scale tensor view - simple 2D layout [M, K/BlockScaleSize/KXdlPack]
const auto scale_a_desc = make_naive_tensor_descriptor_packed(
make_tuple(kargs.M, scale_k_packed));
// A scale tensor view - layout [M, scale_k_size_packed] with packed int32_t
// Host packs 4 consecutive e8m0_t scales into one int32_t
// const auto scale_a_desc = make_naive_tensor_descriptor(
// make_tuple(kargs.M, scale_k_size_packed),
// make_tuple(scale_k_size_packed, 1));
const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
// const auto scale_a_tensor_view = make_tensor_view<address_space_enum::global>(
// reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr),
make_tuple(kargs.M, scale_k_size_packed),
make_tuple(scale_k_size_packed, 1));
// Create block window for scale A
// K dimension: KIterPerWarp int32s, each int32 contains 4 scales for K_Lane threads
// i_m is element offset (iM * MPerBlock), not tile index
auto scale_a_block_window = make_tile_window(
scale_a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
{block_idx_m, 0});
{i_m, 0});
return scale_a_block_window;
}
@@ -267,26 +287,35 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
// Create scale B block windows following the pattern of MakeBBlockWindows
template <typename ScaleM, typename ScaleN>
CK_TILE_DEVICE static auto
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t block_idx_n)
MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs, const index_t i_n)
{
auto scale_b = kargs.scale_n_ptr;
static constexpr int BlockScaleSize = ScaleN::GranularityK;
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack;
const auto scale_k_size = kargs.K / BlockScaleSize;
const auto scale_k_size_packed = scale_k_size / KXdlPack;
// B scale tensor view - layout [K/BlockScaleSize/KXdlPack, N]
const auto scale_b_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_k_packed, kargs.N));
// B scale tensor view - layout [N, scale_k_size_packed] (row stride scale_k_size_packed) with packed int32_t
// Host packs 4 consecutive e8m0_t scales into one int32_t
// const auto scale_b_desc = make_naive_tensor_descriptor(
// make_tuple(kargs.N, scale_k_size_packed),
// make_tuple(scale_k_size_packed, 1));
const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
// const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
// reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr),
make_tuple(kargs.N, scale_k_size_packed),
make_tuple(scale_k_size_packed, 1));
// Create block window for scale B
// i_n is element offset (iN * NPerBlock), not tile index
auto scale_b_block_window = make_tile_window(
scale_b_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{},
number<TilePartitioner::NPerBlock>{}),
{0, block_idx_n});
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPack>{}),
{i_n, 0});
return scale_b_block_window;
}
@@ -301,19 +330,20 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
void* smem_ptr_pong,
const KernelArgs<ScaleM, ScaleN>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
const index_t i_m,
const index_t i_n)
{
// Create block windows directly, following the new pattern from UniversalGemmKernel
// i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices
const auto& a_block_window =
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, i_m);
const auto& b_block_window =
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n);
const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n);
// Create scale block windows using our new functions
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, block_idx_m);
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, block_idx_n);
const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m);
const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
@@ -322,6 +352,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|| ScaleN::GranularityMN == -1, // or ScaleB is disabled
"ScaleM and ScaleN should have the same GranularityK");
const auto& c_block_tile = MXGemmPipeline{}(a_block_window[number<0>{}],
b_block_window[number<0>{}],
scale_a_block_window,
@@ -332,7 +363,7 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
// Run Epilogue Pipeline - create C block window directly
auto c_block_window =
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, block_idx_m, block_idx_n);
MakeCBlockWindows<EpiloguePipeline::MemoryOperation>(e_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
@@ -352,6 +383,11 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
{
const int total_work_tile_cnt = amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N));
// Allocate shared memory OUTSIDE the loop - __shared__ variables must be at function scope
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
// Support both persistent and non-persistent modes
do
{
const auto [iM, iN] =
@@ -377,10 +413,6 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
splitk_batch_offset.bs_k_split_offset[i] / BPackedSize;
});
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
@@ -294,7 +295,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register /////////////////
// A DRAM tile window(s) for load
// A DRAM tile window(s) for load
auto a_tile_windows = generate_tuple(
[&](auto idx) {
@@ -410,33 +411,35 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
constexpr index_t NWarp = BlockWarps::at(I1{});
constexpr index_t MPerXdl = WarpTile::at(I0{});
constexpr index_t NPerXdl = WarpTile::at(I1{});
constexpr index_t KPerXdl = WarpTile::at(I2{});
constexpr index_t ScaleBlockSize = 32;
constexpr index_t ScaleBlockSize = 32; // Each scale covers 32 K elements
// Scale A DRAM Window: [MWarp * MPerXdl, kKPerBlock / 32 / KXdlPack]
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 scales
// Each int32 packs KXdlPack=4 scales, so we need KPerBlock/32/4 int32s per block
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPack; // Packed int32s per block
static_assert(ScaleBlockSize == 32, "Scale block size must be 32 for MX format");
// Scale A DRAM Window: [MWarp * MPerXdl, ScaleKDimPerBlock]
// With strided packing: KXdlPack kIters share each int32 via OpSel
auto scale_a_dram_window = make_tile_window(
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * MPerXdl>{}, number<KPerBlock / ScaleBlockSize / KXdlPack>{}),
make_tuple(number<MWarp * MPerXdl>{}, number<ScaleKDimPerBlock>{}),
scale_a_window.get_window_origin(),
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<MWarp * MPerXdl>, number<0>>{}));
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<0>, number<KPerBlock / ScaleBlockSize / KXdlPack>>{}));
// Scale B DRAM Window: [kKPerBlock / 32 / KXdlPack, NWarp * NPerXdl]
// Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl]
// With strided packing: KXdlPack kIters share each int32 via OpSel
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<KPerBlock / ScaleBlockSize / KXdlPack>{}, number<NWarp * NPerXdl>{}),
make_tuple(number<ScaleKDimPerBlock>{}, number<NWarp * NPerXdl>{}),
scale_b_window.get_window_origin(),
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<KPerBlock / ScaleBlockSize / KXdlPack>, number<0>>{}));
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<0>, number<NWarp * NPerXdl>>{}));
scale_b_dram_window.get_load_offset(tuple<number<NWarp * NPerXdl>, number<0>>{}));
// this pipeline has a pair of LDS buffers per logical tile
// TODO: check for packed size - are these blocks too big?
@@ -447,6 +450,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
// set up LDS tile shapes - always use STORAGE dimensions for K
/// NOTE: flatmm style byte tensor approach:
// constexpr auto a_lds_shape = []() {
@@ -544,23 +548,29 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
static_assert(Policy::template GetSmemSizeB<Problem>() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
// Calculate scale iterations: each scale covers 32 elements in K
// Each K iteration processes KPerXdl elements
// Each packed int32 contains KXdlPack scales
// Calculate scale iterations for M/N dimensions
constexpr index_t KPerXdl = WarpTile::at(I2{});
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack);
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!");
// Load a sample scale tile to get the type
// ScaleKPackedPerIter: number of int32s needed to cover all KIterPerWarp iterations
// Each int32 packs 4 scales (via strided packing), OpSel selects byte for kIter
// KXdlPack kIters share one int32, so we need KIterPerWarp/KXdlPack int32s total
constexpr index_t ScaleKPackedPerIter = KIterPerWarp / KXdlPack;
static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!");
// Load a sample scale tile to get the type after distribution
auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple<number<0>, number<0>>{});
auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple<number<0>, number<0>>{});
using ScaleTileElementA = remove_cvref_t<decltype(scale_a_sample)>;
using ScaleTileElementB = remove_cvref_t<decltype(scale_b_sample)>;
using ScaleATileType = statically_indexed_array<statically_indexed_array<ScaleTileElementA, ScaleKPackedPerIter>, MIterPerWarp>;
using ScaleBTileType = statically_indexed_array<statically_indexed_array<ScaleTileElementB, ScaleKPackedPerIter>, NIterPerWarp>;
// ScaleATileType: array of distributed tensors, one per M/N iteration
// Each distributed tensor holds ScaleKPackedPerIter int32 elements across threads
using ScaleATileType = statically_indexed_array<ScaleTileElementA, MIterPerWarp>;
using ScaleBTileType = statically_indexed_array<ScaleTileElementB, NIterPerWarp>;
ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;
@@ -569,20 +579,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
// Load scales for each M/N iteration
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
scale_a(mIter)(kPacked) = load_tile_with_offset(
scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
});
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// scale_a(mIter)(kPacked) = load_tile_with_offset(
// scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k);
// });
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// Scale B is [K/32/KXdlPack, N], so K is first dimension
scale_b(nIter)(kPacked) = load_tile_with_offset(
scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
});
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// // Scale B is [K/32/KXdlPack, N], so K is first dimension
// scale_b(nIter)(kPacked) = load_tile_with_offset(
// scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
// });
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
});
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0});
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
};
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
@@ -734,7 +746,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
@@ -763,11 +775,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_pong;
// ignore = scale_b_tile_pong;
HotLoopScheduler();
// Load scales for iteration i+2 (pong)
/// TODO: check condition
@@ -787,11 +799,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_ping;
// ignore = scale_b_tile_ping;
/// TODO: load next scales to ping for the last iteration
}
{
@@ -801,19 +813,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_pong;
// ignore = scale_b_tile_pong;
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_ping;
// ignore = scale_b_tile_ping;
}
}
else if(TailNum == TailNumber::Two)
@@ -824,30 +836,30 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_ping;
// ignore = scale_b_tile_ping;
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_pong;
// ignore = scale_b_tile_pong;
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_ping;
// ignore = scale_b_tile_ping;
__builtin_amdgcn_sched_barrier(0);
}

View File

@@ -25,6 +25,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
static constexpr int MXdlPack = 1; // No M packing
static constexpr int NXdlPack = 1; // No N packing
static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32
static constexpr int BlockScaleSize = 32; // Each e8m0 scale covers 32 elements in K
// Override vector size methods to ensure compatibility with async buffer operations
// Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes
@@ -72,222 +73,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return vector_size;
}
// // DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
// {
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// // constexpr index_t BlockSize = Problem::kBlockSize;
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// // constexpr index_t K2 = 16; // 16 bytes
// // constexpr index_t K1 = 128 / K2; // 8
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
// // constexpr index_t M2 = get_warp_size() / K1; // 8
// // constexpr index_t M1 = BlockSize / get_warp_size(); // 4
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
// // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
// // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
// // "K0, K1, K2 must cover whole KPerBlock!");
// // return make_static_tile_distribution(
// // tile_distribution_encoding< //
// // sequence<1>,
// // tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
// // tuple<sequence<1>, sequence<2, 1>>,
// // sequence<1, 2, 2>, // M0,K0,K2
// // sequence<0, 0, 2>>{});
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
// /// NOTE: use original KPerBlock
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// using ALayout = remove_cvref_t<
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
// if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// using TileEncodingPattern =
// tile_distribution_encoding_pattern_2d<BlockSize,
// MPerBlock,
// KPerBlock, // Use storage dimensions
// VecLoadSize,
// getATileAccessPattern(),
// NumWaveGroups>;
// return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// else
// {
// static_assert(false, "Not implemented");
// // using TileEncodingPattern =
// // tile_distribution_encoding_pattern_2d<BlockSize,
// // KPerBlock,
// // MPerBlock,
// // VecLoadSize,
// // getATileAccessPattern(),
// // NumWaveGroups>;
// // return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
// {
// /// NOTE: flatmm style dstr
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// // constexpr index_t BlockSize = Problem::kBlockSize;
// // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// // constexpr index_t K2 = 16; // 16 bytes
// // constexpr index_t K1 = 128 / K2; // 8
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize
// // constexpr index_t N2 = get_warp_size() / K1; // 8
// // constexpr index_t N1 = BlockSize / get_warp_size(); // 4
// // constexpr index_t N0 = NPerBlock / (N2 * N1);
// // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
// // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock,
// // "K0, K1, K2 must cover whole KPerBlock!");
// // return make_static_tile_distribution(
// // tile_distribution_encoding< //
// // sequence<1>,
// // tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
// // tuple<sequence<1>, sequence<2, 1>>,
// // sequence<1, 2, 2>, // N0,K0,K2
// // sequence<0, 0, 2>>{});
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
// /// NOTE: use original KPerBlock
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// using BLayout = remove_cvref_t<
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// static_assert(false, "Not implemented");
// }
// else
// {
// using TileEncodingPattern =
// tile_distribution_encoding_pattern_2d<BlockSize,
// NPerBlock,
// KPerBlock, // Use storage dimensions
// VecLoadSize,
// getBTileAccessPattern(),
// NumWaveGroups>;
// return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution()
// {
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
// // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
// constexpr index_t DWORDx4 = 16;
// constexpr index_t AK1 = DWORDx4 * APackedSize;
// if constexpr(K_Thread == AK1)
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<NWarps>,
// tuple<sequence<MWarps, 1, MPerXdl>, sequence<K_Lane, AK1 / APackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<0, 2>>,
// sequence<2>,
// sequence<1>>{});
// else
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<NWarps>,
// tuple<sequence<MWarps, 1, MPerXdl>,
// sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<1, 2>>,
// sequence<2, 2>,
// sequence<0, 2>>{});
// }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution()
// {
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
// // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
// constexpr index_t DWORDx4 = 16;
// constexpr index_t BK1 = DWORDx4 * BPackedSize;
// if constexpr(K_Thread == BK1)
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<MWarps>,
// tuple<sequence<NWarps, 1, NPerXdl>, sequence<K_Lane, BK1 / BPackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<0, 2>>,
// sequence<2>,
// sequence<1>>{});
// else
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<MWarps>,
// tuple<sequence<NWarps, 1, NPerXdl>,
// sequence<K_Thread / BK1, K_Lane, BK1 / BPackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<1, 2>>,
// sequence<2, 2>,
// sequence<0, 2>>{});
// }
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
@@ -413,8 +198,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
}
// MX Scale tile distributions for loading from global memory
// Using the proven "Flat" patterns from v1 policy
// MX Scale tile distributions for loading from global memory
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
@@ -425,20 +209,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 4 for 16x16 mfma
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
// Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile
// Distribution: simple 2D for loading int32 packed scales
// TODO: check which layout to actually use (could use KxN)
// Scale A: [MWarp * MPerXdl, ScaleKDimPerBlock] warp-level tile
// For K=512: [16, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
// Strided packing: the thread at K_Lane = k gets one int32 holding the scales for all kIters at K position k
// Distribution: replicated across NWarps; M and K are distributed across threads (no vectorization - scalar loads)
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
tuple<sequence<MWarp, MPerXdl>, // M dimension
sequence<K_Lane, 1>>, // K dimension (int32 vec load)
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
tile_distribution_encoding<sequence<NWarp>, // repeat over NWarps
tuple<sequence<MWarp, MPerXdl>, // M dimension
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
tuple<sequence<1, 0>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
sequence<0>>{});
}
template <typename Problem>
@@ -451,20 +238,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t MWarp = BlockWarps::at(number<0>{});
constexpr index_t NWarp = BlockWarps::at(number<1>{});
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
// Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile
// Layout is [K, N] where K is packed int32
// TODO: check which layout to actually use (could use KxN)
// Scale B: [ScaleKDimPerBlock, NWarp * NPerXdl] warp-level tile
// For K=512: [4, 64], distribute 4 int32s across 4 K_Lane threads (1 each)
// Strided packing: the thread at K_Lane = k gets one int32 holding the scales for all kIters at K position k
// Distribution: replicated across MWarps; N and K are distributed across threads (no vectorization - scalar loads)
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
tuple<sequence<K_Lane, 1>, // K dimension (int32 vec load)
sequence<NWarp, NPerXdl>>, // N dimension
tuple<sequence<2, 1>, sequence<0, 1>>, // which direction
tuple<sequence<0, 1>, sequence<0, 0>>, // which index
// <repeat, vec_load>
sequence<1>,
sequence<1>>{});
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
tuple<sequence<NWarp, NPerXdl>, // N dimension
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, NPerXdl>
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
sequence<0>>{});
}
};
} // namespace ck_tile