init=1 and init=2 are working; some scales are still wrong because init=0 is failing

This commit is contained in:
Sami Remes
2026-02-05 10:28:49 +00:00
parent 6c61804665
commit 350022827f
3 changed files with 26 additions and 30 deletions

View File

@@ -295,21 +295,17 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
const auto scale_k_size = kargs.K / BlockScaleSize;
const auto scale_k_size_packed = scale_k_size / KXdlPack;
// B scale tensor view - layout [scale_k_size_packed, N] with packed int32_t
// Host packs 4 consecutive e8m0_t scales into one int32_t
// const auto scale_b_desc = make_naive_tensor_descriptor(
// make_tuple(kargs.N, scale_k_size_packed),
// make_tuple(scale_k_size_packed, 1));
// const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
// reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
// B scale tensor view - for col-major B, we access as [N, K] for better coalescing
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
// After packing: stored as [K/128, N] col-major
// But we create view as [N, K/128] to match the access pattern (each thread handles one N)
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr),
make_tuple(kargs.N, scale_k_size_packed),
make_tuple(scale_k_size_packed, 1));
make_tuple(kargs.N, scale_k_size_packed), // [N, K/32/4] for access
make_tuple(scale_k_size_packed, 1)); // stride to match col-major storage
// Create block window for scale B
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32/4]
// i_n is element offset (iN * NPerBlock), not tile index
auto scale_b_block_window = make_tile_window(
scale_b_tensor_view,

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
@@ -450,7 +451,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
// set up LDS tile shapes - always use STORAGE dimensions for K
/// NOTE: flatmm style byte tensor approach:
// constexpr auto a_lds_shape = []() {
@@ -586,13 +586,12 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
// // Scale B is [K/32/KXdlPack, N], so K is first dimension
// scale_b(nIter)(kPacked) = load_tile_with_offset(
// scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
// });
// Scale B viewed as [N, K], so N is first dimension
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
});
// Advance to next KPerBlock
// Scale A: [M, K] -> advance in K (second dimension)
// Scale B: viewed as [N, K] -> advance in K (second dimension)
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
};
@@ -746,11 +745,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
// /// TODO: remove these after creating a block gemm with scales
// ignore = scale_a_tile_ping;
// ignore = scale_b_tile_ping;
HotLoopScheduler();
// Load scales for iteration i+2 (ping)
if (i_global_read + 2 < num_loop) {

View File

@@ -243,17 +243,18 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
// Scale B: [ScaleKDimPerBlock, NWarp * NPerXdl] warp-level tile
// For K=512: [4, 64], distribute 4 int32s across 4 K_Lane threads (1 each)
// Scale B: [NWarp * NPerXdl, ScaleKDimPerBlock] warp-level tile
// Viewed as [N, K] = [64, 4] for K=512 (access pattern, not storage)
// For K=512: [64, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
// Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k
// Distribution: Distribute in K dimension (no vectorization - scalar loads), replicate in N dimension
// Distribution: Replicate in N dimension, distribute in K dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
tuple<sequence<NWarp, NPerXdl>, // N dimension
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, NPerXdl>
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
tuple<sequence<NWarp, NPerXdl>, // N dimension (first)
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension (second)
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<1, 1>>, // which index
sequence<2>, // replicate N
sequence<0>>{});
}
};