diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index 97e26e756f..67873ab4f8 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -152,30 +152,40 @@ struct MXGemmKernel : UniversalGemmKernel& kargs) { - hipDeviceProp_t prop; - int deviceId = 0; // default device + const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + + if constexpr(UsePersistentKernel) + { + hipDeviceProp_t prop; + int deviceId = 0; // default device - int dync_smem_size = 0; - int maxActiveBlocksPerCU = 0; + int dync_smem_size = 0; + int maxActiveBlocksPerCU = 0; - if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess) - throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") + - hipGetErrorName(hipGetLastError())); + if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess) + throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") + + hipGetErrorName(hipGetLastError())); - if(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &maxActiveBlocksPerCU, - reinterpret_cast( - kentry<1, MXGemmKernel, remove_cvref_t>), - KernelBlockSize, - dync_smem_size) != hipSuccess) - throw std::runtime_error( - std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") + - hipGetErrorName(hipGetLastError())); + if(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &maxActiveBlocksPerCU, + reinterpret_cast( + kentry<1, MXGemmKernel, remove_cvref_t>), + KernelBlockSize, + dync_smem_size) != hipSuccess) + throw std::runtime_error( + std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") + + hipGetErrorName(hipGetLastError())); - const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; - const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; + const int actual_grid_size = min(persistent_block_size, total_work_tile_cnt); - return dim3(min(persistent_block_size, total_work_tile_cnt), 1, 1); + return dim3(actual_grid_size, 1, 1); + } + else + { + // Non-persistent: use full grid size based on number of tiles + return dim3(total_work_tile_cnt, 1, 1); + } } using SplitKBatchOffset = typename Underlying::SplitKBatchOffset; @@ -240,26 +250,36 @@ struct MXGemmKernel : UniversalGemmKernel CK_TILE_DEVICE static auto - MakeScaleABlockWindows(const KernelArgs& kargs, const index_t block_idx_m) + MakeScaleABlockWindows(const KernelArgs& kargs, const index_t i_m) { auto scale_a = kargs.scale_m_ptr; static constexpr int BlockScaleSize = ScaleM::GranularityK; - const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack; + const auto scale_k_size = kargs.K / BlockScaleSize; + const auto scale_k_size_packed = scale_k_size / KXdlPack; - // A scale tensor view - simple 2D layout [M, K/BlockScaleSize/KXdlPack] - const auto scale_a_desc = make_naive_tensor_descriptor_packed( - make_tuple(kargs.M, scale_k_packed)); + // A scale tensor view - layout [M, scale_k_size_packed] with packed int32_t + // Host packs 4 consecutive e8m0_t scales into one int32_t + // const auto scale_a_desc = make_naive_tensor_descriptor( + // make_tuple(kargs.M, scale_k_size_packed), + // make_tuple(scale_k_size_packed, 1)); - const auto scale_a_tensor_view = make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); + // const auto scale_a_tensor_view = make_tensor_view( + // reinterpret_cast(scale_a.ptr), scale_a_desc); + + const auto scale_a_tensor_view = make_naive_tensor_view( + reinterpret_cast(scale_a.ptr), + make_tuple(kargs.M, scale_k_size_packed), + make_tuple(scale_k_size_packed, 1)); // Create block window for scale A + // K dimension: KIterPerWarp int32s, each int32 contains 4 scales for K_Lane threads + // i_m is element offset (iM * MPerBlock), not tile index auto scale_a_block_window = make_tile_window( scale_a_tensor_view, make_tuple(number{}, number{}), - {block_idx_m, 0}); + {i_m, 0}); return scale_a_block_window; } @@ -267,26 +287,35 @@ struct MXGemmKernel : UniversalGemmKernel CK_TILE_DEVICE static auto - MakeScaleBBlockWindows(const KernelArgs& kargs, const index_t block_idx_n) + MakeScaleBBlockWindows(const KernelArgs& kargs, const index_t i_n) { auto scale_b = kargs.scale_n_ptr; static constexpr int BlockScaleSize = ScaleN::GranularityK; - const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPack; + const auto scale_k_size = kargs.K / BlockScaleSize; + const auto scale_k_size_packed = scale_k_size / KXdlPack; - // B scale tensor view - layout [K/BlockScaleSize/KXdlPack, N] - const auto scale_b_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_k_packed, kargs.N)); + // B scale tensor view - layout [scale_k_size_packed, N] with packed int32_t + // Host packs 4 consecutive e8m0_t scales into one int32_t + // const auto scale_b_desc = make_naive_tensor_descriptor( + // make_tuple(kargs.N, scale_k_size_packed), + // make_tuple(scale_k_size_packed, 1)); - const auto scale_b_tensor_view = make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); + // const auto scale_b_tensor_view = make_tensor_view( + // reinterpret_cast(scale_b.ptr), scale_b_desc); + + const auto scale_b_tensor_view = make_naive_tensor_view( + reinterpret_cast(scale_b.ptr), + make_tuple(kargs.N, scale_k_size_packed), + make_tuple(scale_k_size_packed, 1)); // Create block window for scale B + // i_n is element offset (iN * NPerBlock), not tile index auto scale_b_block_window = make_tile_window( scale_b_tensor_view, - make_tuple(number{}, - number{}), - {0, block_idx_n}); + make_tuple(number{}, + number{}), + {i_n, 0}); return scale_b_block_window; } @@ -301,19 +330,20 @@ struct MXGemmKernel : UniversalGemmKernel& kargs, const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + const index_t i_m, + const index_t i_n) { // Create block windows directly, following the new pattern from UniversalGemmKernel + // i_m and i_n are element offsets (iM * MPerBlock, iN * NPerBlock), not tile indices const auto& a_block_window = - Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + Underlying::MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, i_m); const auto& b_block_window = - Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + Underlying::MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, i_n); + const auto& d_block_window = Underlying::MakeDBlockWindows(ds_ptr, kargs, i_m, i_n); // Create scale block windows using our new functions - const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, block_idx_m); - const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, block_idx_n); + const auto& scale_a_block_window = MakeScaleABlockWindows(kargs, i_m); + const auto& scale_b_block_window = MakeScaleBBlockWindows(kargs, i_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); @@ -322,6 +352,7 @@ struct MXGemmKernel : UniversalGemmKernel{}], b_block_window[number<0>{}], scale_a_block_window, @@ -332,7 +363,7 @@ struct MXGemmKernel : UniversalGemmKernel(e_ptr, kargs, block_idx_m, block_idx_n); + MakeCBlockWindows(e_ptr, kargs, i_m, i_n); EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); } @@ -352,6 +383,11 @@ struct MXGemmKernel : UniversalGemmKernel::value)) diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 9af8654e5b..2115f37bed 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/load_tile.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" @@ -294,7 +295,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< "B block window has incorrect lengths for defined BLayout!"); ////////////// global window & register ///////////////// - // A DRAM tile window(s) for load + // A DRAM tile window(s) for load auto a_tile_windows = generate_tuple( [&](auto idx) { @@ -410,33 +411,35 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< constexpr index_t NWarp = BlockWarps::at(I1{}); constexpr index_t MPerXdl = WarpTile::at(I0{}); constexpr index_t NPerXdl = WarpTile::at(I1{}); - constexpr index_t KPerXdl = WarpTile::at(I2{}); - constexpr index_t ScaleBlockSize = 32; + constexpr index_t ScaleBlockSize = 32; // Each scale covers 32 K elements - // Scale A DRAM Window: [MWarp * MPerXdl, kKPerBlock / 32 / KXdlPack] + // Calculate scale dimensions: KPerBlock elements need KPerBlock/32 scales + // Each int32 packs KXdlPack=4 scales, so we need KPerBlock/32/4 int32s per block + constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPack; // Packed int32s per block + static_assert(ScaleBlockSize == 32, "Scale block size must be 32 for MX format"); + + // Scale A DRAM Window: [MWarp * MPerXdl, ScaleKDimPerBlock] + // With strided packing: KXdlPack kIters share each int32 via OpSel auto scale_a_dram_window = make_tile_window( scale_a_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), scale_a_window.get_window_origin(), Policy::template MakeMX_ScaleA_DramTileDistribution()); const auto scale_a_dram_step_m = amd_wave_read_first_lane( scale_a_dram_window.get_load_offset(tuple, number<0>>{})); - const auto scale_a_dram_step_k = amd_wave_read_first_lane( - scale_a_dram_window.get_load_offset(tuple, number>{})); - // Scale B DRAM Window: [kKPerBlock / 32 / KXdlPack, NWarp * NPerXdl] + // Scale B DRAM Window: [ScaleKDimPerBlock, NWarp * NPerXdl] + // With strided packing: KXdlPack kIters share each int32 via OpSel auto scale_b_dram_window = make_tile_window( scale_b_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), scale_b_window.get_window_origin(), Policy::template MakeMX_ScaleB_DramTileDistribution()); - const auto scale_b_dram_step_k = amd_wave_read_first_lane( - scale_b_dram_window.get_load_offset(tuple, number<0>>{})); const auto scale_b_dram_step_n = amd_wave_read_first_lane( - scale_b_dram_window.get_load_offset(tuple, number>{})); + scale_b_dram_window.get_load_offset(tuple, number<0>>{})); // this pipeline has a pair of LDS buffers per logical tile // TODO: check for packed size - are these blocks too big? @@ -447,6 +450,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + // set up LDS tile shapes - always use STORAGE dimensions for K /// NOTE: flatmm style byte tensor approach: // constexpr auto a_lds_shape = []() { @@ -544,23 +548,29 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< static_assert(Policy::template GetSmemSizeB() == (KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock, "SmemSizeB size is wrong!"); ////////////// MX Scale register tiles (ping-pong buffers) ///////////////// - // Calculate scale iterations: each scale covers 32 elements in K - // Each K iteration processes KPerXdl elements - // Each packed int32 contains KXdlPack scales + // Calculate scale iterations for M/N dimensions + constexpr index_t KPerXdl = WarpTile::at(I2{}); constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl); - constexpr index_t ScaleKPackedPerIter = (KIterPerWarp * KPerXdl) / (ScaleBlockSize * KXdlPack); - static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter is wrong!"); - // Load a sample scale tile to get the type + // ScaleKPackedPerIter: number of int32s needed to cover all KIterPerWarp iterations + // Each int32 packs 4 scales (via strided packing), OpSel selects byte for kIter + // KXdlPack kIters share one int32, so we need KIterPerWarp/KXdlPack int32s total + constexpr index_t ScaleKPackedPerIter = KIterPerWarp / KXdlPack; + static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!"); + + // Load a sample scale tile to get the type after distribution auto scale_a_sample = load_tile_with_offset(scale_a_dram_window, tuple, number<0>>{}); auto scale_b_sample = load_tile_with_offset(scale_b_dram_window, tuple, number<0>>{}); using ScaleTileElementA = remove_cvref_t; using ScaleTileElementB = remove_cvref_t; - using ScaleATileType = statically_indexed_array, MIterPerWarp>; - using ScaleBTileType = statically_indexed_array, NIterPerWarp>; + + // ScaleATileType: array of distributed tensors, one per M/N iteration + // Each distributed tensor holds ScaleKPackedPerIter int32 elements across threads + using ScaleATileType = statically_indexed_array; + using ScaleBTileType = statically_indexed_array; ScaleATileType scale_a_tile_ping, scale_a_tile_pong; ScaleBTileType scale_b_tile_ping, scale_b_tile_pong; @@ -569,20 +579,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< auto load_scales_ = [&](auto& scale_a, auto& scale_b) { // Load scales for each M/N iteration static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) { - scale_a(mIter)(kPacked) = load_tile_with_offset( - scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k); - }); + // static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) { + // scale_a(mIter)(kPacked) = load_tile_with_offset( + // scale_a_dram_window, mIter * scale_a_dram_step_m + kPacked * scale_a_dram_step_k); + // }); + scale_a(mIter) = load_tile_with_offset(scale_a_dram_window, make_tuple(mIter * scale_a_dram_step_m, number<0>{})); }); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) { - // Scale B is [K/32/KXdlPack, N], so K is first dimension - scale_b(nIter)(kPacked) = load_tile_with_offset( - scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n); - }); + // static_for<0, ScaleKPackedPerIter, 1>{}([&](auto kPacked) { + // // Scale B is [K/32/KXdlPack, N], so K is first dimension + // scale_b(nIter)(kPacked) = load_tile_with_offset( + // scale_b_dram_window, kPacked * scale_b_dram_step_k + nIter * scale_b_dram_step_n); + // }); + scale_b(nIter) = load_tile_with_offset(scale_b_dram_window, make_tuple(nIter * scale_b_dram_step_n, number<0>{})); }); move_tile_window(scale_a_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack}); - move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0}); + move_tile_window(scale_b_dram_window, {0, KPerBlock / ScaleBlockSize / KXdlPack}); }; // constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() { @@ -734,7 +746,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-3) = A(i-3) @ B(i-3) with MX scaling - // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); block_gemm(c_block_tile, a_block_tile0, b_block_tile0); /// TODO: remove these after creating a block gemm with scales ignore = scale_a_tile_ping; @@ -763,11 +775,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< b_tile_windows[number<0>{}], b_dram_tile_window_step); // C(i-2) = A(i-2) @ B(i-2) with MX scaling - // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); - block_gemm(c_block_tile, a_block_tile1, b_block_tile1); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_pong; - ignore = scale_b_tile_pong; + block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_pong; + // ignore = scale_b_tile_pong; HotLoopScheduler(); // Load scales for iteration i+2 (pong) /// TODO: check condition @@ -787,11 +799,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); // C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling - // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_ping; - ignore = scale_b_tile_ping; + block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_ping; + // ignore = scale_b_tile_ping; /// TODO: load next scales to ping for the last iteration } { @@ -801,19 +813,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v); // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); - block_gemm(c_block_tile, a_block_tile1, b_block_tile1); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_pong; - ignore = scale_b_tile_pong; + block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_pong; + // ignore = scale_b_tile_pong; } { // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_ping; - ignore = scale_b_tile_ping; + block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_ping; + // ignore = scale_b_tile_ping; } } else if(TailNum == TailNumber::Two) @@ -824,30 +836,30 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v); // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling - // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_ping; - ignore = scale_b_tile_ping; + block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_ping; + // ignore = scale_b_tile_ping; } { // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - // warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); - block_gemm(c_block_tile, a_block_tile1, b_block_tile1); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_pong; - ignore = scale_b_tile_pong; + block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong); + // block_gemm(c_block_tile, a_block_tile1, b_block_tile1); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_pong; + // ignore = scale_b_tile_pong; } } else if(TailNum == TailNumber::One) { block_sync_lds(); // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling - // warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); - block_gemm(c_block_tile, a_block_tile0, b_block_tile0); - /// TODO: remove these after creating a block gemm with scales - ignore = scale_a_tile_ping; - ignore = scale_b_tile_ping; + block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping); + // block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + // /// TODO: remove these after creating a block gemm with scales + // ignore = scale_a_tile_ping; + // ignore = scale_b_tile_ping; __builtin_amdgcn_sched_barrier(0); } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index 55c7efb10a..72a9b09571 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -25,6 +25,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy static constexpr int MXdlPack = 1; // No M packing static constexpr int NXdlPack = 1; // No N packing static constexpr int KXdlPack = 4; // Pack 4 consecutive e8m0 scales in K = 4 bytes = 1 int32 + static constexpr int BlockScaleSize = 32; // Each e8m0 scale covers 32 elements in K // Override vector size methods to ensure compatibility with async buffer operations // Valid sizes for amd_async_buffer_load are 4, 12, or 16 bytes @@ -72,222 +73,6 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return vector_size; } - // // DRAM tile distributions use STORAGE dimensions (for the storage tensor view) - // template - // CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - // { - // // using AsDataType = remove_cvref_t; - // // using ADataType = remove_cvref_t{}, AsDataType>>; - - // // constexpr index_t BlockSize = Problem::kBlockSize; - // // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - // // constexpr index_t APackedSize = numeric_traits>::PackedSize; - - // // constexpr index_t K2 = 16; // 16 bytes - // // constexpr index_t K1 = 128 / K2; // 8 - // // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize - - // // constexpr index_t M2 = get_warp_size() / K1; // 8 - // // constexpr index_t M1 = BlockSize / get_warp_size(); // 4 - // // constexpr index_t M0 = MPerBlock / (M2 * M1); - - // // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); - // // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock, - // // "K0, K1, K2 must cover whole KPerBlock!"); - - // // return make_static_tile_distribution( - // // tile_distribution_encoding< // - // // sequence<1>, - // // tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 - // // tuple, sequence<1, 2>>, // M1 M2,K1 - // // tuple, sequence<2, 1>>, - // // sequence<1, 2, 2>, // M0,K0,K2 - // // sequence<0, 0, 2>>{}); - // constexpr index_t BlockSize = Problem::kBlockSize; - // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - // /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions - // // using AsDataType = remove_cvref_t; - // // using ADataType = remove_cvref_t{}, AsDataType>>; - // // constexpr index_t APackedSize = numeric_traits>::PackedSize; - // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions - // /// NOTE: use original KPerBlock - // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - // constexpr index_t VecLoadSize = GetVectorSizeA(); - // constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - // using ALayout = remove_cvref_t< - // std::tuple_element_t{}, remove_cvref_t>>; - - - // if constexpr(std::is_same_v) - // { - // using TileEncodingPattern = - // tile_distribution_encoding_pattern_2d; - // return TileEncodingPattern::make_2d_static_tile_distribution(); - // } - // else - // { - // static_assert(false, "Not implemented"); - // // using TileEncodingPattern = - // // tile_distribution_encoding_pattern_2d; - // // return TileEncodingPattern::make_2d_static_tile_distribution(); - // } - // } - - // template - // CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - // { - // /// NOTE: flatmm style dstr - // // using BsDataType = remove_cvref_t; - // // using BDataType = remove_cvref_t{}, BsDataType>>; - - // // constexpr index_t BlockSize = Problem::kBlockSize; - // // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - // // constexpr index_t BPackedSize = numeric_traits>::PackedSize; - - // // constexpr index_t K2 = 16; // 16 bytes - // // constexpr index_t K1 = 128 / K2; // 8 - // // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize - - // // constexpr index_t N2 = get_warp_size() / K1; // 8 - // // constexpr index_t N1 = BlockSize / get_warp_size(); // 4 - // // constexpr index_t N0 = NPerBlock / (N2 * N1); - - // // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!"); - // // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock, - // // "K0, K1, K2 must cover whole KPerBlock!"); - - // // return make_static_tile_distribution( - // // tile_distribution_encoding< // - // // sequence<1>, - // // tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 - // // tuple, sequence<1, 2>>, // M1 M2,K1 - // // tuple, sequence<2, 1>>, - // // sequence<1, 2, 2>, // N0,K0,K2 - // // sequence<0, 0, 2>>{}); - // constexpr index_t BlockSize = Problem::kBlockSize; - // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - // /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions - // // using BsDataType = remove_cvref_t; - // // using BDataType = remove_cvref_t{}, BsDataType>>; - // // constexpr index_t BPackedSize = numeric_traits>::PackedSize; - // // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions - // /// NOTE: use original KPerBlock - // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - // constexpr index_t VecLoadSize = GetVectorSizeB(); - // constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - - // using BLayout = remove_cvref_t< - // std::tuple_element_t{}, remove_cvref_t>>; - - - // if constexpr(std::is_same_v) - // { - // static_assert(false, "Not implemented"); - // } - // else - // { - // using TileEncodingPattern = - // tile_distribution_encoding_pattern_2d; - // return TileEncodingPattern::make_2d_static_tile_distribution(); - // } - // } - - // template - // CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution() - // { - // // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - // using AsDataType = remove_cvref_t; - // using ADataType = remove_cvref_t{}, AsDataType>>; - // constexpr index_t APackedSize = numeric_traits>::PackedSize; - // using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - // constexpr index_t MWarps = BlockWarps::at(number<0>{}); - // constexpr index_t NWarps = BlockWarps::at(number<1>{}); - // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - // // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - // constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2); - // constexpr index_t K_Lane = get_warp_size() / 16; // 4 - // constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 - // constexpr index_t DWORDx4 = 16; - // constexpr index_t AK1 = DWORDx4 * APackedSize; - - // if constexpr(K_Thread == AK1) - // return make_static_tile_distribution( - // tile_distribution_encoding< // - // sequence, - // tuple, sequence>, - // tuple, sequence<2, 1>>, - // tuple, sequence<0, 2>>, - // sequence<2>, - // sequence<1>>{}); - // else - // return make_static_tile_distribution( - // tile_distribution_encoding< // - // sequence, - // tuple, - // sequence>, - // tuple, sequence<2, 1>>, - // tuple, sequence<1, 2>>, - // sequence<2, 2>, - // sequence<0, 2>>{}); - // } - - // template - // CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution() - // { - // // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1"); - // using BsDataType = remove_cvref_t; - // using BDataType = remove_cvref_t{}, BsDataType>>; - // constexpr index_t BPackedSize = numeric_traits>::PackedSize; - // using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - // constexpr index_t MWarps = BlockWarps::at(number<0>{}); - // constexpr index_t NWarps = BlockWarps::at(number<1>{}); - // // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - // constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2); - // constexpr index_t K_Lane = get_warp_size() / 16; // 4 - // constexpr index_t K_Thread = KPerXdl / K_Lane; // 32 - // constexpr index_t DWORDx4 = 16; - // constexpr index_t BK1 = DWORDx4 * BPackedSize; - - // if constexpr(K_Thread == BK1) - // return make_static_tile_distribution( - // tile_distribution_encoding< // - // sequence, - // tuple, sequence>, - // tuple, sequence<2, 1>>, - // tuple, sequence<0, 2>>, - // sequence<2>, - // sequence<1>>{}); - // else - // return make_static_tile_distribution( - // tile_distribution_encoding< // - // sequence, - // tuple, - // sequence>, - // tuple, sequence<2, 1>>, - // tuple, sequence<1, 2>>, - // sequence<2, 2>, - // sequence<0, 2>>{}); - // } - template > CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -413,8 +198,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return BlockGemmARegBRegCRegV1{}; } - // MX Scale tile distributions for loading from global memory - // Using the proven "Flat" patterns from v1 policy + // MX Scale tile distributions for loading from global memory template CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { @@ -425,20 +209,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy constexpr index_t MWarp = BlockWarps::at(number<0>{}); constexpr index_t NWarp = BlockWarps::at(number<1>{}); constexpr index_t MPerXdl = WarpTile::at(number<0>{}); - constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 4 for 16x16 mfma + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block + constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension + // constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512 - // Scale A: [MWarp * MPerXdl, K/32/KXdlPack] for warp-level tile - // Distribution: simple 2D for loading int32 packed scales - // TODO: check which layout to actually use (could use KxN) + // Scale A: [MWarp * MPerXdl, ScaleKDimPerBlock] warp-level tile + // For K=512: [16, 4], distribute 4 int32s across 4 K_Lane threads (1 each) + // Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k + // Distribution: Replicate in M dimension, distribute in K dimension (no vectorization - scalar loads) return make_static_tile_distribution( - tile_distribution_encoding, // repeat over NWarps - tuple, // M dimension - sequence>, // K dimension (int32 vec load) - tuple, sequence<2, 1>>, // which direction - tuple, sequence<0, 1>>, // which index - // - sequence<2>, - sequence<1>>{}); + tile_distribution_encoding, // repeat over NWarps + tuple, // M dimension + sequence>, // K dimension + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 1>>, + sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock + sequence<0>>{}); } template @@ -451,20 +238,23 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy constexpr index_t MWarp = BlockWarps::at(number<0>{}); constexpr index_t NWarp = BlockWarps::at(number<1>{}); constexpr index_t NPerXdl = WarpTile::at(number<1>{}); - constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 4 for 16x16 mfma + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block + constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension + // constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512 - // Scale B: [K/32/KXdlPack, NWarp * NPerXdl] for warp-level tile - // Layout is [K, N] where K is packed int32 - // TODO: check which layout to actually use (could use KxN) + // Scale B: [ScaleKDimPerBlock, NWarp * NPerXdl] warp-level tile + // For K=512: [4, 64], distribute 4 int32s across 4 K_Lane threads (1 each) + // Strided packing: thread at K_lane=k gets one int32 with scales for all kIters at K position k + // Distribution: Distribute in K dimension (no vectorization - scalar loads), replicate in N dimension return make_static_tile_distribution( - tile_distribution_encoding, // repeat over MWarps - tuple, // K dimension (int32 vec load) - sequence>, // N dimension - tuple, sequence<0, 1>>, // which direction - tuple, sequence<0, 0>>, // which index - // - sequence<1>, - sequence<1>>{}); + tile_distribution_encoding, // repeat over MWarps + tuple, // N dimension + sequence>, // K dimension + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 1>>, + sequence<2>, // ScaleKDimPerBlock, all int32 needed to cover KPerBlock + sequence<0>>{}); } }; } // namespace ck_tile