diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index d1f395a045..c4f9c2afda 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -272,7 +272,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
 
     if(!preshuffle && GemmConfig::UseStructuredSparsity)
     {
-        ck_tile::AdjustToStructuredSparsity{}(a_m_k);
+        if constexpr(GemmConfig::UseStructuredSparsity)
+        {
+            ck_tile::AdjustToStructuredSparsity{}(a_m_k);
+        }
     }
 
     ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
index 2fa3576a4c..a4f67ecaff 100644
--- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
+++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
@@ -14,7 +14,56 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v
     return ck_tile::make_tuple(rtol, atol);
 }
 
-// Use e8m0_t directly without packing - simpler and cleaner approach
+// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t.
+// Each int32_t contains MNPack * KPack e8m0_t values, with a byte layout that matches
+// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
+// byte[ik * MNPack + imn] = e8m0 value at the strided (mn, k) position.
+// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N]).
+template <ck_tile::index_t MNPack,
+          ck_tile::index_t KPack,
+          ck_tile::index_t XdlMNThread,
+          ck_tile::index_t XdlKThread>
+auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
+{
+    auto src_lengths = src.get_lengths();
+    const ck_tile::index_t MN           = kLast ? src_lengths[0] : src_lengths[1];
+    const ck_tile::index_t K_scale      = kLast ? src_lengths[1] : src_lengths[0];
+    const ck_tile::index_t MN_packed    = MN / MNPack;
+    const ck_tile::index_t K_packed     = K_scale / KPack;
+    const ck_tile::index_t total_packed = MN_packed * K_packed;
+
+    // Output as a flat vector of int32_t (row-major [MN/MNPack, K/32/KPack])
+    std::vector<int32_t> packed(total_packed);
+
+    for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
+    {
+        for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
+        {
+            int32_t val = 0;
+            ck_tile::index_t mn_lane  = packed_mn % XdlMNThread;
+            ck_tile::index_t mn_group = packed_mn / XdlMNThread;
+            ck_tile::index_t k_lane   = packed_k % XdlKThread;
+            ck_tile::index_t k_group  = packed_k / XdlKThread;
+            for(ck_tile::index_t ik = 0; ik < KPack; ik++)
+            {
+                for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
+                {
+                    ck_tile::index_t byteIdx = ik * MNPack + imn;
+                    ck_tile::index_t orig_mn =
+                        mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
+                    ck_tile::index_t orig_k =
+                        k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
+
+                    ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
+                    val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
+                }
+            }
+            packed[packed_mn * K_packed + packed_k] = val;
+        }
+    }
+    return packed;
+}
+
 template
@@ ... @@
+    constexpr ck_tile::index_t MXdlPackEff =
+        (MIterPerWarp_ >= 2 && MIterPerWarp_ % 2 == 0) ? 2 : 1;
+    constexpr ck_tile::index_t NXdlPackEff =
+        (NIterPerWarp_ >= 2 && NIterPerWarp_ % 2 == 0) ? 2 : 1;
+    constexpr ck_tile::index_t KXdlPackEff =
+        (KIterPerWarp_ >= 2 && KIterPerWarp_ % 2 == 0) ? 2 : 1;
+
+    // Pack scales: [M, K/32] e8m0_t → [M/MXdlPackEff, K/32/KXdlPackEff] int32_t.
+    // The original unpacked tensors are kept for CPU reference validation.
+    constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
+    constexpr ck_tile::index_t XdlKThread  = 64 / XdlMNThread;
+
+    auto scale_a_packed =
+        packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
+    auto scale_b_packed =
+        packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
+
+    // Device buffers for A, B, C, and the packed scale tensors
     ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
+    ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
 
     a_dev_buf.ToDevice(a_host.data());
     b_dev_buf.ToDevice(b_host.data());
-    c_dev_buf.SetZero(); // Initialize C buffer to zero
-    scale_a_dev_buf.ToDevice(scale_a_host.data());
-    scale_b_dev_buf.ToDevice(scale_b_host.data());
+    c_dev_buf.SetZero();
+    scale_a_dev_buf.ToDevice(scale_a_packed.data());
+    scale_b_dev_buf.ToDevice(scale_b_packed.data());
 
-    // Scale pointers - use e8m0_t* directly
-    using ScaleM = ck_tile::MXScalePointer; // in blocks of 32 in K
+    // Scale pointers - point to packed int32_t data; the kernel reinterprets them as int32_t*
+    using ScaleM = ck_tile::MXScalePointer;
     using ScaleN = ck_tile::MXScalePointer;
     ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer()));
     ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer()));
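A quick way to sanity-check the byte layout above is to run the same index arithmetic on a tiny tensor. The sketch below is standalone and uses plain std:: types in place of ck_tile::HostTensor and e8m0_t; the sizes and byte values are purely illustrative.

// Standalone sketch of packScalesMNxK's byte layout (plain std:: types stand in
// for the ck_tile ones; sizes and values are illustrative).
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    constexpr int MNPack = 2, KPack = 2;           // e8m0 values per int32_t in MN and K
    constexpr int XdlMNThread = 2, XdlKThread = 2; // lane strides of the tile distribution
    constexpr int MN = 4, K = 4;                   // tiny stand-in for a [MN, K/32] scale tensor

    // row-major [MN, K] scale tensor with recognizable byte values 0xMK
    std::vector<uint8_t> src(MN * K);
    for(int mn = 0; mn < MN; mn++)
        for(int k = 0; k < K; k++)
            src[mn * K + k] = static_cast<uint8_t>(0x10 * mn + k);

    const int MN_packed = MN / MNPack, K_packed = K / KPack;
    std::vector<int32_t> packed(MN_packed * K_packed);

    for(int pmn = 0; pmn < MN_packed; pmn++)
        for(int pk = 0; pk < K_packed; pk++)
        {
            int32_t val = 0;
            const int mn_lane = pmn % XdlMNThread, mn_group = pmn / XdlMNThread;
            const int k_lane = pk % XdlKThread, k_group = pk / XdlKThread;
            for(int ik = 0; ik < KPack; ik++)
                for(int imn = 0; imn < MNPack; imn++)
                {
                    // same strided gather as the helper above
                    const int byteIdx = ik * MNPack + imn;
                    const int omn = mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
                    const int ok  = k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
                    val |= static_cast<int32_t>(src[omn * K + ok]) << (byteIdx * 8);
                }
            packed[pmn * K_packed + pk] = val;
        }

    // packed(0, 0) gathers the scales that are XdlMNThread apart in MN and
    // XdlKThread apart in K: (0,0)->byte0, (2,0)->byte1, (0,2)->byte2, (2,2)->byte3
    assert(packed[0] == 0x22022000);
    return 0;
}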
diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
index 1a7b1f4c96..108afd9b1c 100644
--- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
@@ -249,14 +249,19 @@ struct BlockGemmARegBRegCRegV1
         });
     }
 
-    // C += A * B with MX scaling
-    // ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t
-    // ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t
+    // C += A * B with MX scaling and the packed-in-two (XdlPack) optimization.
+    // The scale tensors contain pre-packed int32_t values: each int32_t holds
+    // MXdlPack * KXdlPack e8m0_t values (for A) or NXdlPack * KXdlPack (for B), packed
+    // on the host. OpSel (0-3) selects the byte within the packed int32_t per MFMA call.
+    // The XdlPack parameters default to 2 and fall back to 1 when the iteration count is too small.
    template <typename CBlockTensor,
              typename ABlockTensor,
              typename BBlockTensor,
              typename ScaleATensor,
-             typename ScaleBTensor>
+             typename ScaleBTensor,
+             index_t MXdlPack_ = 2,
+             index_t NXdlPack_ = 2,
+             index_t KXdlPack_ = 2>
     CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                    const ABlockTensor& a_block_tensor,
                                    const BBlockTensor& b_block_tensor,
@@ -304,53 +309,88 @@ struct BlockGemmARegBRegCRegV1
         constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{};
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};
 
-        // hot loop with MX scaling:
-        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // read A warp tensor from A Block window
-                AWarpTensor a_warp_tensor;
-                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
-                    merge_sequences(sequence{}, a_warp_y_index_zeros),
-                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
+        // Effective XdlPack: fall back to 1 when the iteration count is insufficient
+        constexpr index_t MXdlPack =
+            (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
+        constexpr index_t NXdlPack =
+            (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
+        constexpr index_t KXdlPack =
+            (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
 
-                // get A scale for this M-K tile using get_y_sliced_thread_data
+        constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
+        constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
+        constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
+
+        // hot loop with MX scaling and pre-packed int32_t scales:
+        // the outer loops iterate over pack groups (scale tile indices)
+        static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
+            static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
+                // get the pre-packed int32_t A scale (holds MXdlPack * KXdlPack e8m0_t values)
                 auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
-                    sequence{}, sequence<1, 1, 1>{});
-                const auto a_scale_e8m0 = scale_a_slice[number<0>{}];
-                const int32_t a_scale   = static_cast<int32_t>(a_scale_e8m0.get());
+                    sequence{}, sequence<1, 1, 1>{});
+                const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
 
-                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    // read B warp tensor from B block tensor
-                    BWarpTensor b_warp_tensor;
-                    b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
-                        merge_sequences(sequence{}, b_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
-
-                    // get B scale for this N-K tile using get_y_sliced_thread_data
+                static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
+                    // get the pre-packed int32_t B scale
                     auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
-                        sequence{}, sequence<1, 1, 1>{});
-                    const auto b_scale_e8m0 = scale_b_slice[number<0>{}];
-                    const int32_t b_scale   = static_cast<int32_t>(b_scale_e8m0.get());
+                        sequence{}, sequence<1, 1, 1>{});
+                    const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
 
-                    // read C warp tensor from C block tensor
-                    using c_iter_idx = std::conditional_t, sequence>;
-                    CWarpTensor c_warp_tensor;
-                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
-                        merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+                    // inner loops: issue the MFMAs within the pack group using OpSel
+                    static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
+                        static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
+                            constexpr auto kIter = ikpack * KXdlPack + ikxdl;
+                            constexpr auto mIter = impack * MXdlPack + imxdl;
 
-                    // warp GEMM with MX scaling
-                    // Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
-                    constexpr index_t kOpSel = 0; // Always use OpSel=0
-                    WarpGemm{}.template operator()<kOpSel, kOpSel>(
-                        c_warp_tensor, a_warp_tensor, b_warp_tensor, a_scale, b_scale);
+                            // read A warp tensor from A block tensor
+                            AWarpTensor a_warp_tensor;
+                            a_warp_tensor.get_thread_buffer() =
+                                a_block_tensor.get_y_sliced_thread_data(
+                                    merge_sequences(sequence{}, a_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
 
-                    // write C warp tensor into C block tensor
-                    c_block_tensor.set_y_sliced_thread_data(
-                        merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
-                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-                        c_warp_tensor.get_thread_buffer());
+                            // OpSel for A: selects the byte within the packed int32_t
+                            constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
+
+                            static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
+                                constexpr auto nIter = inpack * NXdlPack + inxdl;
+
+                                // read B warp tensor from B block tensor
+                                BWarpTensor b_warp_tensor;
+                                b_warp_tensor.get_thread_buffer() =
+                                    b_block_tensor.get_y_sliced_thread_data(
+                                        merge_sequences(sequence{}, b_warp_y_index_zeros),
+                                        merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
+
+                                // OpSel for B: selects the byte within the packed int32_t
+                                constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
+
+                                // read C warp tensor from C block tensor
+                                using c_iter_idx = std::conditional_t, sequence>;
+                                CWarpTensor c_warp_tensor;
+                                c_warp_tensor.get_thread_buffer() =
+                                    c_block_tensor.get_y_sliced_thread_data(
+                                        merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
+                                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
+
+                                // warp GEMM with MX scaling using the pre-packed scales and OpSel
+                                WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
+                                                                                 a_warp_tensor,
+                                                                                 b_warp_tensor,
+                                                                                 a_scale_packed,
+                                                                                 b_scale_packed);
+
+                                // write C warp tensor into C block tensor
+                                c_block_tensor.set_y_sliced_thread_data(
+                                    merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
+                                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
+                                    c_warp_tensor.get_thread_buffer());
+                            });
+                        });
+                    });
                 });
             });
         });
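The packed scales only work because the OpSel indices in the block GEMM enumerate the same byte positions the host wrote: kOpSelA = ikxdl * MXdlPack + imxdl matches byteIdx = ik * MNPack + imn in packScalesMNxK. A small software emulation of the byte selection (the MFMA hardware performs this on the 32-bit scale operand; select_scale_byte is just an illustrative stand-in):

#include <cstdint>
#include <cstdio>

// Emulate OpSel: select byte `opsel` (0-3) from a packed int32_t scale operand.
static uint8_t select_scale_byte(int32_t packed, int opsel)
{
    return static_cast<uint8_t>((packed >> (opsel * 8)) & 0xFF);
}

int main()
{
    constexpr int MXdlPack = 2, KXdlPack = 2;
    const int32_t a_scale_packed = 0x7F7E7D7C; // four e8m0 exponents, bytes 0x7C..0x7F

    for(int ikxdl = 0; ikxdl < KXdlPack; ikxdl++)
        for(int imxdl = 0; imxdl < MXdlPack; imxdl++)
        {
            // same formula as kOpSelA in the block GEMM hot loop
            const int kOpSelA = ikxdl * MXdlPack + imxdl;
            std::printf("ikxdl=%d imxdl=%d -> OpSel=%d -> e8m0 byte 0x%02X\n",
                        ikxdl, imxdl, kOpSelA, select_scale_byte(a_scale_packed, kOpSelA));
        }
    return 0;
}

Each of the four MFMA calls in a pack group therefore consumes exactly the e8m0 value the host stored for its (mIter, kIter) pair, without any per-call scale loads.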
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp
index a9f6dced9d..2f2a67deae 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp
@@ -118,7 +118,12 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return Policy::template GetSmemSize<Problem>();
+        // We are not storing the original packed type in LDS, so the smem size
+        // needs to be multiplied by the packed size.
+        constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
+        constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
+
+        return 2 * (smem_size_a + smem_size_b);
     }
 
     static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;
diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp
index 57278a504a..a6428b88ac 100644
--- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp
+++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp
@@ -98,6 +98,30 @@ struct MXGemmKernel : UniversalGemmKernel
     static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
     static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
 
+    // XdlPack: desired packing of e8m0_t scale values into int32_t
+    static constexpr index_t MXdlPack = 2;
+    static constexpr index_t NXdlPack = 2;
+    static constexpr index_t KXdlPack = 2;
+
+    // Effective pack sizes: fall back to 1 when the dimension is too small
+    using BlockWarps_ = typename BlockGemmShape::BlockWarps;
+    static constexpr index_t MPerBlock_ = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock_ = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock_ = BlockGemmShape::kK;
+    static constexpr index_t MWarp_     = BlockWarps_::at(number<0>{});
+    static constexpr index_t NWarp_     = BlockWarps_::at(number<1>{});
+    static constexpr index_t KPerXdl_   = BlockGemmShape::WarpTile::at(number<2>{});
+    static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl);
+    static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl);
+    static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_;
+
+    static constexpr index_t MXdlPackEff =
+        (MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1;
+    static constexpr index_t NXdlPackEff =
+        (NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1;
+    static constexpr index_t KXdlPackEff =
+        (KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1;
+
     static constexpr int kBlockPerCu = 1;
 
     static_assert(DsLayout::size() == DsDataType::size(),
@@ -245,7 +269,9 @@ struct MXGemmKernel : UniversalGemmKernel
 
-    // Create scale A block windows following the pattern of MakeABlockWindows
+    // Create scale A block windows with packed int32_t layout
+    // Host packs 2M x 2K e8m0_t values into one int32_t
+    // Tensor view: [M/MXdlPack, K/32/KXdlPack] of int32_t
     template
     CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs& kargs, const index_t i_m)
@@ -253,28 +279,28 @@ struct MXGemmKernel : UniversalGemmKernel
+        const index_t scale_k_packed = scale_k_size / KXdlPackEff;
+        const index_t scale_m_packed = kargs.M / MXdlPackEff;
+
         auto scale_a_tensor_view = make_naive_tensor_view(
-            reinterpret_cast<const ck_tile::e8m0_t*>(scale_a.ptr),
-            make_tuple(kargs.M, scale_k_size),
-            make_tuple(scale_k_size, 1));
+            reinterpret_cast<const int32_t*>(scale_a.ptr),
+            make_tuple(scale_m_packed, scale_k_packed),
+            make_tuple(scale_k_packed, 1));
 
-        // Create block window for scale A
-        // K dimension: scale_k_size e8m0_t elements
-        // i_m is element offset (iM * MPerBlock), not tile index
-        auto scale_a_block_window =
-            make_tile_window(scale_a_tensor_view,
-                             make_tuple(number{},
-                                        number{}),
-                             {i_m, 0});
+        // Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff]
+        auto scale_a_block_window = make_tile_window(
+            scale_a_tensor_view,
+            make_tuple(number{},
+                       number{}),
+            {i_m / MXdlPackEff, 0});
 
         return scale_a_block_window;
     }
 
-    // Create scale B block windows following the pattern of MakeBBlockWindows
+    // Create scale B block windows with packed int32_t layout
+    // Host packs 2N x 2K e8m0_t values into one int32_t
+    // Tensor view: [N/NXdlPack, K/32/KXdlPack] of int32_t
     template
     CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs& kargs, const index_t i_n)
@@ -282,23 +308,21 @@ struct MXGemmKernel : UniversalGemmKernel
+        const index_t scale_k_packed = scale_k_size / KXdlPackEff;
+        const index_t scale_n_packed = kargs.N / NXdlPackEff;
+
         auto scale_b_tensor_view = make_naive_tensor_view(
-            reinterpret_cast<const ck_tile::e8m0_t*>(scale_b.ptr),
-            make_tuple(kargs.N, scale_k_size), // [N, K/32] for access
-            make_tuple(scale_k_size, 1));      // stride to match col-major storage
+            reinterpret_cast<const int32_t*>(scale_b.ptr),
+            make_tuple(scale_n_packed, scale_k_packed),
+            make_tuple(scale_k_packed, 1));
 
-        // Create block window for scale B
-        // Tile window shape matches access pattern: [NPerBlock, KPerBlock/32]
-        // i_n is element offset (iN * NPerBlock)
-        auto scale_b_block_window =
-            make_tile_window(scale_b_tensor_view,
-                             make_tuple(number{},
-                                        number{}),
-                             {i_n, 0});
+        // Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff]
+        auto scale_b_block_window = make_tile_window(
+            scale_b_tensor_view,
+            make_tuple(number{},
+                       number{}),
+            {i_n / NXdlPackEff, 0});
 
         return scale_b_block_window;
     }
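The same "effective pack" fallback appears in the example, the kernel, the pipeline, and the policy, so the rule is worth checking in isolation: a pack of 2 is only used when the per-warp iteration count is a non-zero multiple of 2, otherwise the code degrades to the unpacked path. A minimal sketch (the tile numbers in main are assumptions for illustration):

#include <cstdio>

// Effective pack: use desired_pack only when the iteration count divides evenly.
constexpr int effective_pack(int iter_per_warp, int desired_pack)
{
    return (iter_per_warp >= desired_pack && iter_per_warp % desired_pack == 0)
               ? desired_pack
               : 1;
}

static_assert(effective_pack(4, 2) == 2, "4 iterations pack in pairs");
static_assert(effective_pack(1, 2) == 1, "a single iteration cannot pack");
static_assert(effective_pack(3, 2) == 1, "an odd iteration count cannot pack");

int main()
{
    // e.g. MPerBlock=128, MWarp=2, MPerXdl=16 -> MIterPerWarp=4 -> effective pack 2,
    // so the packed scale window has half as many rows of int32_t per block
    const int m_iter_per_warp = 128 / (2 * 16);
    std::printf("MXdlPackEff = %d\n", effective_pack(m_iter_per_warp, 2));
    return 0;
}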
diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
index 88058ac2ac..ec16a4e8b6 100644
--- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
+++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
@@ -315,14 +315,36 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync
             },
             number{});
 
-        ////////////// MX Scale windows /////////////////
+        ////////////// MX Scale windows (pre-packed int32_t) /////////////////
 
         // Get WarpGemm configuration
         using BlockWarps = typename BlockGemmShape::BlockWarps;
+        using WarpTile   = typename BlockGemmShape::WarpTile;
         constexpr index_t MWarp = BlockWarps::at(I0{});
         constexpr index_t NWarp = BlockWarps::at(I1{});
 
-        // Calculate scale dimensions: KPerBlock elements need KPerBlock/32 e8m0_t scales
-        constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize;
+        // Compute effective XdlPack sizes (fall back to 1 when the iteration count < pack)
+        constexpr index_t MPerXdl      = WarpTile::at(I0{});
+        constexpr index_t NPerXdl      = WarpTile::at(I1{});
+        constexpr index_t KPerXdl      = WarpTile::at(I2{});
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
+        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
+        constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
+
+        constexpr index_t MXdlPackEff =
+            (MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0)
+                ? Policy::MXdlPack
+                : 1;
+        constexpr index_t NXdlPackEff =
+            (NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0)
+                ? Policy::NXdlPack
+                : 1;
+        constexpr index_t KXdlPackEff =
+            (KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0)
+                ? Policy::KXdlPack
+                : 1;
+
+        // Packed scale dimensions
+        constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff;
 
         // Scale tensor views and base origins for creating tile windows per iteration
         const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
@@ -330,18 +352,18 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync
         const auto& scale_b_tensor_view = scale_b_window.get_bottom_tensor_view();
         auto scale_a_base_origin = scale_a_window.get_window_origin();
         auto scale_b_base_origin = scale_b_window.get_window_origin();
 
-        // Create sample scale windows to determine tile types
-        auto scale_a_dram_window =
-            make_tile_window(scale_a_tensor_view,
-                             make_tuple(number{}, number{}),
-                             scale_a_base_origin,
-                             Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
+        // Create scale windows with packed int32_t dimensions
+        auto scale_a_dram_window = make_tile_window(
+            scale_a_tensor_view,
+            make_tuple(number{}, number{}),
+            scale_a_base_origin,
+            Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
 
-        auto scale_b_dram_window =
-            make_tile_window(scale_b_tensor_view,
-                             make_tuple(number{}, number{}),
-                             scale_b_base_origin,
-                             Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
+        auto scale_b_dram_window = make_tile_window(
+            scale_b_tensor_view,
+            make_tuple(number{}, number{}),
+            scale_b_base_origin,
+            Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
 
         // this pipeline has a pair of LDS buffers per logical tile
         auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
@@ -427,8 +449,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync
                       "SmemSizeB size is wrong!");
 
         ////////////// MX Scale register tiles (ping-pong buffers) /////////////////
-        // No packing needed - each thread gets e8m0_t elements directly
-        // Each thread will cast e8m0_t to int32_t for WarpGemm with OpSel=0
+        // Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values.
+        // The block GEMM uses OpSel (0-3) to select the right byte per MFMA call.
         using ScaleATileType = decltype(load_tile(scale_a_dram_window));
         using ScaleBTileType = decltype(load_tile(scale_b_dram_window));
diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
index f9f1794b10..d90271d235 100644
--- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
+++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
@@ -131,7 +131,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         return BlockGemmARegBRegCRegV1{};
     }
 
-    // MX Scale tile distributions for loading from global memory
+    // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension.
+    // The host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales and
+    // NXdlPack * KXdlPack e8m0_t into one int32_t for B scales.
+    static constexpr int MXdlPack = 2;
+    static constexpr int NXdlPack = 2;
+    static constexpr int KXdlPack = 2;
+
+    // MX Scale tile distributions for loading pre-packed int32_t from global memory.
+    // Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t.
     template
     CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
     {
@@ -145,21 +153,29 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         constexpr index_t MPerXdl   = WarpTile::at(number<0>{});
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
 
-        constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
+        constexpr index_t K_Lane = get_warp_size() / MPerXdl;
 
         constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
         constexpr index_t KPerXdl      = WarpTile::at(number<2>{});
         constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
         constexpr index_t KPerLane     = KPerXdl / BlockScaleSize / K_Lane;
 
+        // Effective pack sizes: fall back to 1 when the iteration count < pack size
+        constexpr index_t MXdlPackEff =
+            (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
+        constexpr index_t KXdlPackEff =
+            (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
+
+        constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
+        constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
+
         return make_static_tile_distribution(
-            tile_distribution_encoding<
-                sequence,            // repeat over MWarps
-                tuple,               // M dimension (first)
-                sequence>,           // K dimension (second)
-                tuple, sequence<2, 1>>,
-                tuple, sequence<1, 2>>,
-                sequence<2, 1, 2>,
-                sequence<0, 0, 2>>{});
+            tile_distribution_encoding,
+                tuple,
+                sequence>,
+                tuple, sequence<2, 1>>,
+                tuple, sequence<1, 2>>,
+                sequence<2, 1, 2>,
+                sequence<0, 0, 2>>{});
     }
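To see what the packed distribution buys per lane, it helps to plug concrete numbers into the arithmetic above. A worked example, assuming a 64-lane wavefront, a 16x16 Xdl tile with KPerXdl = 128, and 32-element scale blocks (all assumptions for illustration):

#include <cstdio>

int main()
{
    const int warp_size      = 64;
    const int MPerXdl        = 16;
    const int KPerXdl        = 128;
    const int BlockScaleSize = 32; // one e8m0 scale per 32 K elements

    const int K_Lane   = warp_size / MPerXdl;               // 4 lanes cover the K direction
    const int KPerLane = KPerXdl / BlockScaleSize / K_Lane; // 1 scale per lane per xdl

    // with KIterPerWarp = 2 and KXdlPackEff = 2, one packed int32_t load per lane
    // replaces two separate e8m0 loads along K
    const int KIterPerWarp        = 2;
    const int KXdlPackEff         = 2;
    const int KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;

    std::printf("K_Lane=%d KPerLane=%d packed K iterations=%d\n",
                K_Lane, KPerLane, KIterPerWarp_packed);
    return 0;
}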
    template
@@ -169,27 +185,35 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
    {
        using BlockWarps = typename BlockGemmShape::BlockWarps;
        using WarpTile   = typename BlockGemmShape::WarpTile;
 
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t MWarp     = BlockWarps::at(number<0>{});
-        constexpr index_t NWarp     = BlockWarps::at(number<1>{});
-        constexpr index_t NPerXdl   = WarpTile::at(number<1>{});
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t K_Lane    = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t MWarp     = BlockWarps::at(number<0>{});
+        constexpr index_t NWarp     = BlockWarps::at(number<1>{});
+        constexpr index_t NPerXdl   = WarpTile::at(number<1>{});
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t K_Lane    = get_warp_size() / NPerXdl;
 
         constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
         constexpr index_t KPerXdl      = WarpTile::at(number<2>{});
         constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
         constexpr index_t KPerLane     = KPerXdl / BlockScaleSize / K_Lane;
 
+        // Effective pack sizes: fall back to 1 when the iteration count < pack size
+        constexpr index_t NXdlPackEff =
+            (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
+        constexpr index_t KXdlPackEff =
+            (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
+
+        constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
+        constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
+
         return make_static_tile_distribution(
-            tile_distribution_encoding<
-                sequence,            // repeat over MWarps
-                tuple,               // N dimension (first)
-                sequence>,           // K dimension (second)
-                tuple, sequence<2, 1>>,
-                tuple, sequence<1, 2>>,
-                sequence<2, 1, 2>,
-                sequence<0, 0, 2>>{});
+            tile_distribution_encoding,
+                tuple,
+                sequence>,
+                tuple, sequence<2, 1>>,
+                tuple, sequence<1, 2>>,
+                sequence<2, 1, 2>,
+                sequence<0, 0, 2>>{});
     }
 };
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp
index 23ad2dd12a..e6249ffa4f 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp
@@ -134,7 +134,12 @@ struct ABQuantGemmPipelineAgBgCrEightWaves : public BaseGemmPipelineAgBgCrCompV3
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return Policy::template GetSmemSize<Problem>();
+        // We are not storing the original packed type in LDS, so the smem size
+        // needs to be multiplied by the packed size.
+        constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
+        constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
+
+        return 2 * (smem_size_a + smem_size_b);
     }
 
     CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWaves\n"; }
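The GetSmemSize change in the two eight-waves pipelines is the same formula: the per-tile sizes are scaled by the packed size (the data is not kept in its packed type in LDS) and the sum is doubled because the pipelines keep a pair of LDS buffers per logical tile. A back-of-envelope sketch of the arithmetic (the tile sizes and fp4 packing are assumptions for illustration):

#include <cstdio>

int main()
{
    // hypothetical packed per-tile LDS sizes, e.g. 128x128 tiles of packed fp4
    const int get_smem_size_a = 128 * 128 / 2; // bytes when stored packed
    const int get_smem_size_b = 128 * 128 / 2;
    const int a_packed_size = 2, b_packed_size = 2; // 2 fp4 values per byte

    const int smem_a = get_smem_size_a * a_packed_size; // unpacked in LDS
    const int smem_b = get_smem_size_b * b_packed_size;

    std::printf("LDS bytes = %d\n", 2 * (smem_a + smem_b)); // ping-pong pair
    return 0;
}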
diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp
index 9019366319..6e7ddfb5d0 100644
--- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp
+++ b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp
@@ -45,6 +45,55 @@ class TestMxGemmUtil : public ::testing::Test
     using ScaleM = ck_tile::MXScalePointer;
     using ScaleN = ck_tile::MXScalePointer;
 
+    // Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t.
+    // Each int32_t contains MNPack * KPack e8m0_t values, with a byte layout that matches
+    // the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
+    // byte[ik * MNPack + imn] = e8m0 value at the strided (mn, k) position.
+    // kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N]).
+    template <ck_tile::index_t MNPack,
+              ck_tile::index_t KPack,
+              ck_tile::index_t XdlMNThread,
+              ck_tile::index_t XdlKThread>
+    static auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
+    {
+        auto src_lengths = src.get_lengths();
+        const ck_tile::index_t MN           = kLast ? src_lengths[0] : src_lengths[1];
+        const ck_tile::index_t K_scale      = kLast ? src_lengths[1] : src_lengths[0];
+        const ck_tile::index_t MN_packed    = MN / MNPack;
+        const ck_tile::index_t K_packed     = K_scale / KPack;
+        const ck_tile::index_t total_packed = MN_packed * K_packed;
+
+        std::vector<int32_t> packed(total_packed);
+
+        for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
+        {
+            for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
+            {
+                int32_t val = 0;
+                ck_tile::index_t mn_lane  = packed_mn % XdlMNThread;
+                ck_tile::index_t mn_group = packed_mn / XdlMNThread;
+                ck_tile::index_t k_lane   = packed_k % XdlKThread;
+                ck_tile::index_t k_group  = packed_k / XdlKThread;
+                for(ck_tile::index_t ik = 0; ik < KPack; ik++)
+                {
+                    for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
+                    {
+                        ck_tile::index_t byteIdx = ik * MNPack + imn;
+                        ck_tile::index_t orig_mn =
+                            mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
+                        ck_tile::index_t orig_k =
+                            k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
+
+                        ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
+                        val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
+                    }
+                }
+                packed[packed_mn * K_packed + packed_k] = val;
+            }
+        }
+        return packed;
+    }
+
     void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
     {
         const ck_tile::index_t scale_k_size = K / 32;
@@ -75,17 +124,43 @@ class TestMxGemmUtil : public ::testing::Test
         ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_a_host);
         ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_b_host);
 
+        // Compute effective XdlPack sizes based on the GemmConfig tile dimensions
+        constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
+        constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile;
+        constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile;
+        constexpr ck_tile::index_t MIterPerWarp =
+            GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl);
+        constexpr ck_tile::index_t NIterPerWarp =
+            GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl);
+        constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl;
+
+        constexpr ck_tile::index_t MXdlPackEff =
+            (MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
+        constexpr ck_tile::index_t NXdlPackEff =
+            (NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
+        constexpr ck_tile::index_t KXdlPackEff =
+            (KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
+
+        constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
+        constexpr ck_tile::index_t XdlKThread  = 64 / XdlMNThread;
+
+        // Pack the scales into int32_t for GPU consumption
+        auto scale_a_packed =
+            packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
+        auto scale_b_packed =
+            packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
+
         ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
         ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
         ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
-        ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
-        ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
+        ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
+        ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
 
         a_dev_buf.ToDevice(a_host.data());
         b_dev_buf.ToDevice(b_host.data());
         c_dev_buf.SetZero();
-        scale_a_dev_buf.ToDevice(scale_a_host.data());
-        scale_b_dev_buf.ToDevice(scale_b_host.data());
+        scale_a_dev_buf.ToDevice(scale_a_packed.data());
+        scale_b_dev_buf.ToDevice(scale_b_packed.data());
 
         ScaleM scale_m(reinterpret_cast(scale_a_dev_buf.GetDeviceBuffer()));
         ScaleN scale_n(reinterpret_cast(scale_b_dev_buf.GetDeviceBuffer()));
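Since the tests validate GPU results against an unpacked CPU reference, a host-side round-trip check of the packing itself can also be useful. The standalone sketch below uses plain std:: types; pack mirrors packScalesMNxK's index arithmetic, and unpack_at is a hypothetical inverse written for this check, not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// forward pack: same index arithmetic as packScalesMNxK, simplified types
static std::vector<int32_t> pack(const std::vector<uint8_t>& src, int MN, int K,
                                 int MNPack, int KPack, int XdlMNThread, int XdlKThread)
{
    const int MN_packed = MN / MNPack, K_packed = K / KPack;
    std::vector<int32_t> out(MN_packed * K_packed, 0);
    for(int pmn = 0; pmn < MN_packed; pmn++)
        for(int pk = 0; pk < K_packed; pk++)
            for(int ik = 0; ik < KPack; ik++)
                for(int imn = 0; imn < MNPack; imn++)
                {
                    const int omn = (pmn / XdlMNThread) * XdlMNThread * MNPack +
                                    imn * XdlMNThread + pmn % XdlMNThread;
                    const int ok = (pk / XdlKThread) * XdlKThread * KPack +
                                   ik * XdlKThread + pk % XdlKThread;
                    out[pmn * K_packed + pk] |=
                        static_cast<int32_t>(src[omn * K + ok]) << ((ik * MNPack + imn) * 8);
                }
    return out;
}

// hypothetical inverse: recover the byte that was stored for original (mn, k)
static uint8_t unpack_at(const std::vector<int32_t>& packed, int K_packed, int mn, int k,
                         int MNPack, int KPack, int XdlMNThread, int XdlKThread)
{
    const int mn_group = mn / (XdlMNThread * MNPack);
    const int imn      = (mn / XdlMNThread) % MNPack;
    const int mn_lane  = mn % XdlMNThread;
    const int k_group  = k / (XdlKThread * KPack);
    const int ik       = (k / XdlKThread) % KPack;
    const int k_lane   = k % XdlKThread;

    const int packed_mn = mn_group * XdlMNThread + mn_lane;
    const int packed_k  = k_group * XdlKThread + k_lane;
    const int byteIdx   = ik * MNPack + imn;
    return static_cast<uint8_t>((packed[packed_mn * K_packed + packed_k] >> (byteIdx * 8)) & 0xFF);
}

int main()
{
    const int MN = 8, K = 8;
    const int MNPack = 2, KPack = 2, XdlMNThread = 2, XdlKThread = 2;

    std::vector<uint8_t> src(MN * K);
    for(int i = 0; i < MN * K; i++)
        src[i] = static_cast<uint8_t>(i);

    const auto packed = pack(src, MN, K, MNPack, KPack, XdlMNThread, XdlKThread);

    // every original (mn, k) scale must come back from the byte the pack stored
    for(int mn = 0; mn < MN; mn++)
        for(int k = 0; k < K; k++)
            assert(unpack_at(packed, K / KPack, mn, k, MNPack, KPack, XdlMNThread, XdlKThread) ==
                   src[mn * K + k]);
    return 0;
}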