mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
init=1 and init=2 are working; some scales are still wrong, as init=0 is still failing
This commit is contained in:
@@ -295,21 +295,17 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_size_packed = scale_k_size / KXdlPack;
|
||||
|
||||
// B scale tensor view - layout [scale_k_size_packed, N] with packed int32_t
|
||||
// Host packs 4 consecutive e8m0_t scales into one int32_t
|
||||
// const auto scale_b_desc = make_naive_tensor_descriptor(
|
||||
// make_tuple(kargs.N, scale_k_size_packed),
|
||||
// make_tuple(scale_k_size_packed, 1));
|
||||
|
||||
// const auto scale_b_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
// reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
|
||||
// B scale tensor view - for col-major B, we access as [N, K] for better coalescing
|
||||
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
|
||||
// After packing: stored as [K/128, N] col-major
|
||||
// But we create view as [N, K/128] to match the access pattern (each thread handles one N)
|
||||
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr),
|
||||
make_tuple(kargs.N, scale_k_size_packed),
|
||||
make_tuple(scale_k_size_packed, 1));
|
||||
make_tuple(kargs.N, scale_k_size_packed), // [N, K/32/4] for access
|
||||
make_tuple(scale_k_size_packed, 1)); // stride to match col-major storage
|
||||
|
||||
// Create block window for scale B
|
||||
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32/4]
|
||||
// i_n is element offset (iN * NPerBlock), not tile index
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
@@ -450,7 +451,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
|
||||
// set up LDS tile shapes - always use STORAGE dimensions for K
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// constexpr auto a_lds_shape = []() {
|
||||
@@ -586,13 +586,12 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{}));
|
||||
});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) {
|
||||
// // Scale B is [K/32/KXdlPack, N], so K is first dimension
|
||||
// scale_b(nIter)(kPacked) = load_tile_with_offset(
|
||||
// scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n);
|
||||
// });
|
||||
// Scale B viewed as [N, K], so N is first dimension
|
||||
scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{}));
|
||||
});
|
||||
// Advance to next KPerBlock
|
||||
// Scale A: [M, K] -> advance in K (second dimension)
|
||||
// Scale B: viewed as [N, K] -> advance in K (second dimension)
|
||||
move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack});
|
||||
};
|
||||
@@ -746,11 +745,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
// /// TODO: remove these after creating a block gemm with scales
|
||||
// ignore = scale_a_tile_ping;
|
||||
// ignore = scale_b_tile_ping;
|
||||
HotLoopScheduler();
|
||||
// Load scales for iteration i+2 (ping)
|
||||
if (i_global_read + 2 < num_loop) {
|
||||
|
||||
@@ -243,17 +243,18 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
|
||||
// constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
|
||||
|
||||
// Scale B: [ScaleKDimPerBlock, NWarp * NPerXdl] warp-level tile
|
||||
// For K=512: [4, 64], distribute 4 int32s across 4 K_Lane threads (1 each)
|
||||
// Scale B: [NWarp * NPerXdl, ScaleKDimPerBlock] warp-level tile
|
||||
// Viewed as [N, K] = [64, 4] for K=512 (access pattern, not storage)
|
||||
// For K=512: [64, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
|
||||
// Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k
|
||||
// Distribution: Distribute in K dimension (no vectorization - scalar loads), replicate in N dimension
|
||||
// Distribution: Replicate in N dimension, distribute in K dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<NWarp, NPerXdl>, // N dimension
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, NPerXdl>
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
|
||||
tuple<sequence<NWarp, NPerXdl>, // N dimension (first)
|
||||
sequence<ScaleKDimPerBlock, K_Lane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>, // which index
|
||||
sequence<2>, // replicate N
|
||||
sequence<0>>{});
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user