diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
index dfdca51671..0e1e08cbef 100644
--- a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
+++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc
@@ -59,15 +59,20 @@ auto pack_scales_for_k_dimension(const ScaleTensor& scale_unpacked,
         for(ck_tile::index_t k = 0; k < pack_size; ++k)
         {
             // Strided packing: byte k corresponds to kIter=k
-            // Stride by packed dimension (new_dim1 for dim1 packing, 1 for dim0 packing since it's linear)
-            // Wait, we need to map unpacked logical positions to correct strided pattern
-            // For K=512: 16 unpacked elements [0-15] map to 4 int32s strided:
-            //   int32[0] = {elem[0], elem[4], elem[8], elem[12]} (bytes 0,1,2,3 for kIter 0,1,2,3)
-            //   int32[1] = {elem[1], elem[5], elem[9], elem[13]}
-            //   ...
-            // So: packed_index j (or i), byte position k -> unpacked_index = j/i + k * packed_size
-            ck_tile::index_t src_i = pack_dim1 ? i : (i + k * packed_k_dim);
-            ck_tile::index_t src_j = pack_dim1 ? (j + k * packed_k_dim) : j;
+            // The stride is always pack_size (4), not packed_k_dim.
+            // For K=512: 16 unpacked elements [0-15] -> 4 packed int32s
+            //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (stride=4)
+            //   int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (stride=4)
+            // For K=1024: 32 unpacked elements [0-31] -> 8 packed int32s
+            //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (stride=4)
+            //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (stride=4)
+            // Row-major (pack_dim1=true):  packed index j, byte k -> unpacked[j*4 + k*4]
+            // Col-major (pack_dim1=false): packed index i, byte k -> unpacked[i*4 + k*4],
+            // i.e. int32[i] packs {unpacked[i*4], unpacked[i*4 + 4], unpacked[i*4 + 8], unpacked[i*4 + 12]}
+            ck_tile::index_t src_i = pack_dim1 ? i : (i * pack_size + k * pack_size);
+            ck_tile::index_t src_j = pack_dim1 ? (j * pack_size + k * pack_size) : j;
             uint8_t scale_byte = *reinterpret_cast<const uint8_t*>(&scale_unpacked(src_i, src_j));
             packed_value |= (static_cast<int32_t>(scale_byte) << (k * 8));
@@ -140,13 +145,14 @@ int run_mx_gemm_with_layouts(int argc,
         ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{})));
     ck_tile::HostTensor<ScaleType> scale_b_host(
         ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{})));

+    int seed = 1234;
     switch(init_method)
     {
     case 0:
-        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
-        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
-        ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_a_host);
-        ck_tile::FillUniformDistribution<ScaleType>{1.f, 10.f}(scale_b_host);
+        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
+        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
+        ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
+        ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
         break;
     case 1:
         ck_tile::FillConstant{ADataType(1.f)}(a_host);
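Note: the strided layout the new comments describe can be sanity-checked in isolation. Below is a minimal host-side sketch, assuming plain byte storage and the grouping implied by the worked examples (within each group of KXdlPack = 4 int32s, int32[g*4 + r] packs bytes u[g*16 + r + 4k]); `pack_scales_strided` is an illustrative name, not the ck_tile API:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Pack one row of e8m0 scale bytes into int32s using the strided pattern
// spelled out in the diff comments: int32[0] = {u[0], u[4], u[8], u[12]}, etc.
std::vector<int32_t> pack_scales_strided(const std::vector<uint8_t>& u)
{
    constexpr std::size_t pack_size = 4; // e8m0 bytes per int32
    std::vector<int32_t> packed(u.size() / pack_size, 0);
    for(std::size_t j = 0; j < packed.size(); ++j)
        for(std::size_t k = 0; k < pack_size; ++k)
        {
            // group g = j/4 covers 16 unpacked scales; r = j%4 is the lane offset
            const std::size_t src = (j / 4) * 16 + (j % 4) + k * 4;
            packed[j] |= static_cast<int32_t>(u[src]) << (k * 8);
        }
    return packed;
}

int main()
{
    std::vector<uint8_t> u(32); // e.g. K=1024 -> 32 scales along K
    for(std::size_t i = 0; i < u.size(); ++i)
        u[i] = static_cast<uint8_t>(i);
    const auto p = pack_scales_strided(u);
    assert(((p[0] >> 0) & 0xFF) == 0 && ((p[0] >> 24) & 0xFF) == 12);  // {0,4,8,12}
    assert(((p[1] >> 0) & 0xFF) == 1 && ((p[1] >> 24) & 0xFF) == 13);  // {1,5,9,13}
    assert(((p[4] >> 0) & 0xFF) == 16 && ((p[4] >> 24) & 0xFF) == 28); // {16,20,24,28}
    return 0;
}
```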
@@ -155,11 +161,82 @@ int run_mx_gemm_with_layouts(int argc,
         ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
         break;
     case 2:
-        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host);
-        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_host);
+        ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f, seed++}(a_host);
+        ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f, seed++}(b_host);
         ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host);
         ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
         break;
+    case 3:
+        // Debug mode: simple power-of-2 pattern for scales (e8m0 format)
+        ck_tile::FillConstant{ADataType(1.f)}(a_host);
+        ck_tile::FillConstant{BDataType(1.f)}(b_host);
+        // Fill scales with power-of-2 pattern: 1.0, 2.0, 4.0, 8.0, 16.0, ...
+        // e8m0 is exponent-only, so these give clear distinct values
+        // for(std::size_t i = 0; i < scale_a_host.mDesc.get_element_space_size(); ++i)
+        // {
+        //     float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
+        //     scale_a_host.mData[i] = ScaleType(val);
+        // }
+        // for(std::size_t i = 0; i < scale_b_host.mDesc.get_element_space_size(); ++i)
+        // {
+        //     float val = std::pow(2.0f, static_cast<float>(i % 16)); // cycle through 2^0 to 2^15
+        //     scale_b_host.mData[i] = ScaleType(val);
+        // }
+        ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host);
+        ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host);
+
+        // Test data to verify K block loading for K=1024 (2 K blocks)
+        // K block 0: K indices 0-511,    scale K indices 0-15,  packed into K_packed indices 0-3
+        // K block 1: K indices 512-1023, scale K indices 16-31, packed into K_packed indices 4-7
+
+        // Scale A: [M, K/32] row-major (unpacked K indices in second dim)
+        // Strided packing: int32[j] packs unpacked[j], unpacked[j+4], unpacked[j+8], unpacked[j+12]
+        // K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
+        //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
+        //   int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
+        scale_a_host(0, 0)  = ScaleType(2.f);  // K block 0, int32[0] byte 0 (unpacked[0])
+        scale_a_host(0, 4)  = ScaleType(4.f);  // K block 0, int32[0] byte 1 (unpacked[4])
+        scale_a_host(0, 8)  = ScaleType(8.f);  // K block 0, int32[0] byte 2 (unpacked[8])
+        scale_a_host(0, 12) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
+        scale_a_host(0, 1)  = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
+
+        // K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
+        //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
+        //   int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
+        scale_a_host(0, 16) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
+        scale_a_host(0, 20) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
+        scale_a_host(0, 24) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
+        scale_a_host(0, 28) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
+        scale_a_host(0, 17) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
+
+        // mIter=1: M rows 16-31 (second XDL block)
+        scale_a_host(16, 0)  = ScaleType(64.f);   // K block 0, unpacked K=0, M=16
+        scale_a_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, M=16
+        // Scale B: [K/32, N] col-major (unpacked K indices in first dim, N in second dim)
+        // Strided packing: int32[i] packs unpacked[i], unpacked[i+4], unpacked[i+8], unpacked[i+12]
+        // K block 0: K indices 0-511 → unpacked K indices 0-15 → packed int32s 0-3
+        //   int32[0] = {unpacked[0], unpacked[4], unpacked[8], unpacked[12]} (K_Lane=0)
+        //   int32[1] = {unpacked[1], unpacked[5], unpacked[9], unpacked[13]} (K_Lane=1)
+        scale_b_host(0, 0)  = ScaleType(2.f);  // K block 0, int32[0] byte 0 (unpacked[0])
+        scale_b_host(4, 0)  = ScaleType(4.f);  // K block 0, int32[0] byte 1 (unpacked[4])
+        scale_b_host(8, 0)  = ScaleType(8.f);  // K block 0, int32[0] byte 2 (unpacked[8])
+        scale_b_host(12, 0) = ScaleType(16.f); // K block 0, int32[0] byte 3 (unpacked[12])
+        scale_b_host(1, 0)  = ScaleType(32.f); // K block 0, int32[1] byte 0 (unpacked[1])
+
+        // K block 1: K indices 512-1023 → unpacked K indices 16-31 → packed int32s 4-7
+        //   int32[4] = {unpacked[16], unpacked[20], unpacked[24], unpacked[28]} (K_Lane=0)
+        //   int32[5] = {unpacked[17], unpacked[21], unpacked[25], unpacked[29]} (K_Lane=1)
+        scale_b_host(16, 0) = ScaleType(256.f);  // K block 1, int32[4] byte 0 (unpacked[16])
+        scale_b_host(20, 0) = ScaleType(512.f);  // K block 1, int32[4] byte 1 (unpacked[20])
+        scale_b_host(24, 0) = ScaleType(1024.f); // K block 1, int32[4] byte 2 (unpacked[24])
+        scale_b_host(28, 0) = ScaleType(2048.f); // K block 1, int32[4] byte 3 (unpacked[28])
+        scale_b_host(17, 0) = ScaleType(4096.f); // K block 1, int32[5] byte 0 (unpacked[17])
+
+        // nIter=1: N columns 16-31 (second XDL block)
+        scale_b_host(0, 16)  = ScaleType(64.f);   // K block 0, unpacked K=0, N=16
+        scale_b_host(16, 16) = ScaleType(8192.f); // K block 1, unpacked K=16, N=16
+        break;
     }

     // Pack scales: 4 consecutive e8m0_t in K dimension → 1 int32 for efficient 32-bit loads
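Note: the power-of-two test values above rely on e8m0 being exponent-only, so each power of two maps to a single distinct byte. A small sketch of that encoding, assuming the OCP MX definition (value = 2^(e − 127), 0xFF reserved for NaN); the helpers are illustrative, not ck_tile's `e8m0_t`:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Encode/decode exact powers of two; rounding of other values is ignored here.
uint8_t e8m0_from_float(float v) { return static_cast<uint8_t>(127 + std::ilogb(v)); }
float   e8m0_to_float(uint8_t e) { return std::ldexp(1.0f, static_cast<int>(e) - 127); }

int main()
{
    assert(e8m0_from_float(1.0f) == 127); // 2^0
    assert(e8m0_from_float(2.0f) == 128); // 2^1
    assert(e8m0_to_float(e8m0_from_float(8192.0f)) == 8192.0f); // 2^13, used above
    return 0;
}
```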
@@ -169,6 +246,48 @@ int run_mx_gemm_with_layouts(int argc,
     auto scale_a_packed = pack_scales_for_k_dimension(scale_a_host, 4);
     auto scale_b_packed = pack_scales_for_k_dimension(scale_b_host, 4);

+    // DEBUG: Print first few packed scale values
+    if(true || init_method == 3)
+    {
+        std::cout << "Host: ScaleA packed [0,0]: ";
+        uint8_t* a_bytes = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 0));
+        std::cout << "[" << static_cast<int>(a_bytes[0]) << "," << static_cast<int>(a_bytes[1]) << ","
+                  << static_cast<int>(a_bytes[2]) << "," << static_cast<int>(a_bytes[3]) << "]\n";
+        std::cout << "Host: ScaleA packed [0,4]: ";
+        uint8_t* a_bytes4 = reinterpret_cast<uint8_t*>(&scale_a_packed(0, 4));
+        std::cout << "[" << static_cast<int>(a_bytes4[0]) << "," << static_cast<int>(a_bytes4[1]) << ","
+                  << static_cast<int>(a_bytes4[2]) << "," << static_cast<int>(a_bytes4[3]) << "]\n";
+        std::cout << "Host: ScaleB packed [0,0]: ";
+        uint8_t* b_bytes = reinterpret_cast<uint8_t*>(&scale_b_packed(0, 0));
+        std::cout << "[" << static_cast<int>(b_bytes[0]) << "," << static_cast<int>(b_bytes[1]) << ","
+                  << static_cast<int>(b_bytes[2]) << "," << static_cast<int>(b_bytes[3]) << "]\n";
+        std::cout << "Host: ScaleB packed [4,0]: ";
+        uint8_t* b_bytes4 = reinterpret_cast<uint8_t*>(&scale_b_packed(4, 0));
+        std::cout << "[" << static_cast<int>(b_bytes4[0]) << "," << static_cast<int>(b_bytes4[1]) << ","
+                  << static_cast<int>(b_bytes4[2]) << "," << static_cast<int>(b_bytes4[3]) << "]\n";
+
+        // Print unpacked first row/col for reference
+        std::cout << "Host: ScaleA unpacked thread 0, every 4th element: [";
+        for(int k = 0; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
+            std::cout << static_cast<int>(*reinterpret_cast<const uint8_t*>(&scale_a_host(0, k))) << ",";
+        std::cout << "]\n";
+        std::cout << "Host: ScaleB unpacked thread 0, every 4th element: [";
+        for(int k = 0; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
+            std::cout << static_cast<int>(*reinterpret_cast<const uint8_t*>(&scale_b_host(k, 0))) << ",";
+        std::cout << "]\n";
+        // Threads 0-15 cover M rows 0-15 at K_Lane=0; thread 16 wraps back to row 0
+        // at the next K_Lane, i.e. the next strided K group (unpacked indices 1, 5, 9, ...)
+        std::cout << "Host: ScaleA unpacked thread 16 (row 0, next K group), every 4th element: [";
+        for(int k = 1; k < std::min(32, static_cast<int>(scale_a_host.mDesc.get_lengths()[1])); k += 4)
+            std::cout << static_cast<int>(*reinterpret_cast<const uint8_t*>(&scale_a_host(0, k))) << ",";
+        std::cout << "]\n";
+        std::cout << "Host: ScaleB unpacked thread 16 (row 0, next K group), every 4th element: [";
+        for(int k = 1; k < std::min(32, static_cast<int>(scale_b_host.mDesc.get_lengths()[0])); k += 4)
+            std::cout << static_cast<int>(*reinterpret_cast<const uint8_t*>(&scale_b_host(k, 0))) << ",";
+        std::cout << "]\n";
+    }
+
     ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
index 09e60cecdb..2f77c9c8c4 100644
--- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
+++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp
@@ -410,8 +410,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         using WarpTile = typename BlockGemmShape::WarpTile;
         constexpr index_t MWarp = BlockWarps::at(I0{});
         constexpr index_t NWarp = BlockWarps::at(I1{});
-        constexpr index_t MPerXdl = WarpTile::at(I0{});
-        constexpr index_t NPerXdl = WarpTile::at(I1{});
+        // constexpr index_t MPerXdl = WarpTile::at(I0{});
+        // constexpr index_t NPerXdl = WarpTile::at(I1{});

         constexpr index_t ScaleBlockSize = 32; // Each scale covers 32 K elements
@@ -427,43 +427,22 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         auto scale_b_base_origin = scale_b_window.get_window_origin();

         // Create sample scale windows to determine tile types
-        auto scale_a_dram_window_sample = make_tile_window(
+        auto scale_a_dram_window = make_tile_window(
             scale_a_tensor_view,
-            make_tuple(number{}, number{}),
+            make_tuple(number{}, number{}),
             scale_a_base_origin,
             Policy::template MakeMX_ScaleA_DramTileDistribution());
-        auto scale_b_dram_window_sample = make_tile_window(
+        auto scale_b_dram_window = make_tile_window(
             scale_b_tensor_view,
-            make_tuple(number{}, number{}),
+            make_tuple(number{}, number{}),
             scale_b_base_origin,
             Policy::template MakeMX_ScaleB_DramTileDistribution());

         // this pipeline has a pair of LDS buffers per logical tile
-        // TODO: check for packed size - are these blocks too big?
-        /// NOTE: flatmm style byte tensor approach:
-        // auto&& [a_lds_block0, b_lds_block0] = Base::template GetABLdsTensorViews(p_smem_0);
-        // auto&& [a_lds_block1, b_lds_block1] = Base::template GetABLdsTensorViews(p_smem_1);
-        /// NOTE: with original fp4 types:
         auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
         auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);

-        // set up LDS tile shapes - always use STORAGE dimensions for K
-        /// NOTE: flatmm style byte tensor approach:
-        // constexpr auto a_lds_shape = []() {
-        //     if constexpr(is_a_load_tr_v)
-        //         return make_tuple(number{}, number{});
-        //     else
-        //         return make_tuple(number{}, number{});
-        // }();
-
-        // constexpr auto b_lds_shape = []() {
-        //     if constexpr(is_b_load_tr_v)
-        //         return make_tuple(number{}, number{});
-        //     else
-        //         return make_tuple(number{}, number{});
-        // }();
-        /// NOTE: use original shapes
         constexpr auto a_lds_shape = []() {
             if constexpr(is_a_load_tr_v)
                 return make_tuple(number{}, number{});
@@ -490,13 +469,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         // initialize DRAM window steps, used to advance the DRAM windows
         using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
         using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
-
-        /// NOTE: flatmm style way to calculate steps with packed size
-        // constexpr ADramTileWindowStep a_dram_tile_window_step =
-        //     is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
-        // constexpr BDramTileWindowStep b_dram_tile_window_step =
-        //     is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
-        /// NOTE: use original steps and assume that PackedSize is correctly applied elsewhere
         constexpr ADramTileWindowStep a_dram_tile_window_step =
             is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
         constexpr BDramTileWindowStep b_dram_tile_window_step =
@@ -509,10 +481,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         Base::GlobalPrefetchAsync(
             b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);

-        // Initialize WarpGemm for MX scaling
-        // using WarpGemm = typename remove_cvref_t())>::WarpGemm;
-        // using CWarpTensor = typename WarpGemm::CWarpTensor;
-
         // Initialize block gemm and C block tile
         auto block_gemm = BlockGemm();
         auto c_block_tile = block_gemm.MakeCBlockTile();
@@ -548,8 +516,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         // Calculate scale iterations for M/N dimensions
         constexpr index_t KPerXdl = WarpTile::at(I2{});
         constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
-        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
-        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
+        // constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
+        // constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);

         // ScaleKPackedPerIter: number of int32s needed to cover all KIterPerWarp iterations
         // Each int32 packs 4 scales (via strided packing), OpSel selects byte for kIter
@@ -557,58 +525,23 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         constexpr index_t ScaleKPackedPerIter = KIterPerWarp / KXdlPack;
        static_assert(ScaleKPackedPerIter > 0, "ScaleKPackedPerIter must be positive!");
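Note: with the constants quoted in these comments, the bookkeeping works out as below. ScaleBlockSize and KXdlPack come from the code; KPerBlock = 512 and KPerXdl = 32 are assumed for illustration:

```cpp
// Compile-time check of the scale-count arithmetic in the pipeline comments.
constexpr int ScaleBlockSize = 32;  // one e8m0 scale covers 32 K elements
constexpr int KXdlPack       = 4;   // scales packed per int32 (strided)
constexpr int KPerBlock      = 512; // assumed
constexpr int KPerXdl        = 32;  // assumed

constexpr int KIterPerWarp        = KPerBlock / KPerXdl;                   // 16
constexpr int ScaleKDimPerBlock   = KPerBlock / ScaleBlockSize / KXdlPack; // 4 int32s per block
constexpr int ScaleKPackedPerIter = KIterPerWarp / KXdlPack;               // 4

static_assert(ScaleKDimPerBlock == 4, "matches the '[16, 4]' tile quoted in the policy");
static_assert(ScaleKPackedPerIter > 0, "mirrors the pipeline's own static_assert");

int main() { return 0; }
```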
-        // Load a sample scale tile to get the type after distribution
-        auto scale_a_sample = load_tile(scale_a_dram_window_sample);
-        auto scale_b_sample = load_tile(scale_b_dram_window_sample);
-
-        using ScaleTileElementA = remove_cvref_t<decltype(scale_a_sample)>;
-        using ScaleTileElementB = remove_cvref_t<decltype(scale_b_sample)>;
-
-        // ScaleATileType: array of distributed tensors, one per M/N iteration
-        // Each distributed tensor holds ScaleKPackedPerIter int32 elements across threads
-        using ScaleATileType = statically_indexed_array<ScaleTileElementA, MIterPerWarp>;
-        using ScaleBTileType = statically_indexed_array<ScaleTileElementB, NIterPerWarp>;
-
+        using ScaleATileType = decltype(load_tile(scale_a_dram_window));
+        using ScaleBTileType = decltype(load_tile(scale_b_dram_window));
         ScaleATileType scale_a_tile_ping, scale_a_tile_pong;
         ScaleBTileType scale_b_tile_ping, scale_b_tile_pong;
+
+        // initialize Scale DRAM window steps, used to advance the Scale DRAM windows
+        using ScaleADramTileWindowStep = typename ScaleADramBlockWindowTmp::BottomTensorIndex;
+        using ScaleBDramTileWindowStep = typename ScaleBDramBlockWindowTmp::BottomTensorIndex;
+        constexpr ScaleADramTileWindowStep scale_a_dram_tile_window_step = make_array(0, ScaleKDimPerBlock);
+        constexpr ScaleBDramTileWindowStep scale_b_dram_tile_window_step = make_array(0, ScaleKDimPerBlock);

         // Helper function to load scales
-        auto load_scales_ = [&](auto& scale_a, auto& scale_b) {
-            // Load scales for each M/N iteration
-            // Create tile windows from scratch with correct origins for each iteration
-            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-                // Scale A: create window at origin {base_m + mIter * MPerXdl, base_k}
-                auto scale_a_origin = scale_a_base_origin;
-                scale_a_origin[number<0>{}] += mIter * MPerXdl;
-
-                auto scale_a_tile_window = make_tile_window(
-                    scale_a_tensor_view,
-                    make_tuple(number{}, number{}),
-                    scale_a_origin,
-                    Policy::template MakeMX_ScaleA_DramTileDistribution());
-
-                scale_a(mIter) = load_tile(scale_a_tile_window);
-            });
-
-            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                // Scale B: layout is [N, K], create window at origin {base_n + nIter * NPerXdl, base_k}
-                auto scale_b_origin = scale_b_base_origin;
-                scale_b_origin[number<0>{}] += nIter * NPerXdl;
-
-                auto scale_b_tile_window = make_tile_window(
-                    scale_b_tensor_view,
-                    make_tuple(number{}, number{}),
-                    scale_b_origin,
-                    Policy::template MakeMX_ScaleB_DramTileDistribution());
-
-                scale_b(nIter) = load_tile(scale_b_tile_window);
-            });
-
-            // Advance base origins to next KPerBlock
-            // Scale A: [M, K] -> advance in K (second dimension, index 1)
-            // Scale B: [N, K] -> advance in K (second dimension, index 1)
-            scale_a_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
-            scale_b_base_origin[number<1>{}] += KPerBlock / ScaleBlockSize / KXdlPack;
+        auto load_scales_once = [&](auto& scale_a, auto& scale_b) {
+            scale_a = load_tile(scale_a_dram_window);
+            scale_b = load_tile(scale_b_dram_window);
+            move_tile_window(scale_a_dram_window, scale_a_dram_tile_window_step);
+            move_tile_window(scale_b_dram_window, scale_b_dram_tile_window_step);
         };

         // constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
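Note: load_scales_once reduces the old per-iteration window rebuild to a single load plus a fixed window step, which makes the ping/pong schedule easy to model. A host-side simulation of that schedule (loop count illustrative; the refill guard models the original `i + 2 < num_loop` condition):

```cpp
#include <cstdio>

int main()
{
    const int num_loop = 7; // illustrative main-loop trip count
    int next_k_block   = 0; // advanced by every scale load, like the DRAM window
    auto load_scales_once = [&](int& buf) { buf = next_k_block++; };

    int ping = -1, pong = -1;
    load_scales_once(ping);                  // scales for iteration 0
    if(num_loop > 1) load_scales_once(pong); // scales for iteration 1

    for(int i = 0; i < num_loop; ++i)
    {
        int& cur = (i % 2 == 0) ? ping : pong; // even iters consume ping, odd pong
        std::printf("iter %d consumes scale K block %d\n", i, cur);
        if(i + 2 < num_loop)
            load_scales_once(cur); // refill the just-consumed buffer 2 iters ahead
    }
    return 0;
}
```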
@@ -641,14 +574,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
             make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, BLdsTileDistr);
         auto b_lds_ld_window1 =
             make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, BLdsTileDistr);

-        // auto a_lds_ld_window0 =
-        //     make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution());
-        // auto a_lds_ld_window1 =
-        //     make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution());
-        // auto b_lds_ld_window0 =
-        //     make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution());
-        // auto b_lds_ld_window1 =
-        //     make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution());

         static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
                       !(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
@@ -656,64 +581,6 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
                       !(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
                       !(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
                       "LDS windows must not be linear");

-        // Create warp-level C tensors (one per M/N iteration)
-        // statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
-
-        // Initialize C tensors
-        /// TODO: create CBlockTile with block_gemm.MakeCBlockTile()
-        // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-        //     static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-        //         clear_tile(c_warp_tensors(mIter)(nIter));
-        //     });
-        // });
-
-        // Warp GEMM loop with MX scaling
-        // auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
-        //     // Extract A/B values from block tiles to warp iteration structure
-        //     constexpr auto a_warp_y_lengths =
-        //         to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
-        //     constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{};
-        //     constexpr auto b_warp_y_lengths =
-        //         to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
-        //     constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{};
-
-        //     static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
-        //         // Map k_iter to packed scale index and OpSel
-        //         constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack);
-        //         // constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack;
-        //         constexpr index_t kScaleInPack = k_iter;
-
-        //         static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
-        //             constexpr auto OpSelA = kScaleInPack;
-
-        //             // read A warp tensor from A block tensor
-        //             typename WarpGemm::AWarpTensor a_warp_tensor;
-
-        //             a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
-        //                 merge_sequences(sequence{}, a_warp_y_index_zeros),
-        //                 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
-
-        //             static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
-        //                 constexpr auto OpSelB = kScaleInPack;
-
-        //                 // read B warp tensor from B block tensor
-        //                 typename WarpGemm::BWarpTensor b_warp_tensor;
-
-        //                 b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
-        //                     merge_sequences(sequence{}, b_warp_y_index_zeros),
-        //                     merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
-
-        //                 WarpGemm{}.template operator()(
-        //                     c_warp_tensors(m_iter)(n_iter),
-        //                     a_warp_tensor,
-        //                     b_warp_tensor,
-        //                     scale_a(m_iter)(number{}).get_thread_buffer()[0],
-        //                     scale_b(n_iter)(number{}).get_thread_buffer()[0]);
-        //             });
-        //         });
-        //     });
-        // };
-
         // write to LDS window(0) must complete before the local prefetch
         block_sync_lds_direct_load();
         // read A(0), B(0) from LDS window(0) to pipeline registers(0)
@@ -729,11 +596,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
             b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);

         // Load scales for iteration 0 (ping)
-        load_scales_(scale_a_tile_ping, scale_b_tile_ping);
+        load_scales_once(scale_a_tile_ping, scale_b_tile_ping);
         // Load scales for iteration 1 (pong) if needed
         if (num_loop > 1) {
-            load_scales_(scale_a_tile_pong, scale_b_tile_pong);
+            load_scales_once(scale_a_tile_pong, scale_b_tile_pong);
         }

         if(HasHotLoop)
@@ -761,14 +628,10 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
                     b_dram_tile_window_step);
                 // C(i-3) = A(i-3) @ B(i-3) with MX scaling
                 block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
-                // block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
-                // /// TODO: remove these after creating a block gemm with scales
-                // ignore = scale_a_tile_ping;
-                // ignore = scale_b_tile_ping;
                 HotLoopScheduler();

                 // Load scales for iteration i+2 (ping)
-                if (i_global_read + 2 < num_loop) {
-                    load_scales_(scale_a_tile_ping, scale_b_tile_ping);
+                if (i_global_read - 1 < num_loop) {
+                    load_scales_once(scale_a_tile_ping, scale_b_tile_ping);
                 }
             }
             // pong
@@ -798,7 +661,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
                 // Load scales for iteration i+2 (pong)
                 /// TODO: check condition
                 if (i_global_read + 2 < num_loop) {
-                    load_scales_(scale_a_tile_pong, scale_b_tile_pong);
+                    load_scales_once(scale_a_tile_pong, scale_b_tile_pong);
                 }
             }
             i_global_read += 2;
@@ -818,7 +681,9 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
             // /// TODO: remove these after creating a block gemm with scales
             // ignore = scale_a_tile_ping;
             // ignore = scale_b_tile_ping;
-            /// TODO: load next scales to ping for the last iteration
+
+            // load the last scales into the ping buffers for the final iteration
+            load_scales_once(scale_a_tile_ping, scale_b_tile_ping);
         }
         {
             // write to LDS window(0) must complete before the local prefetch
@@ -845,54 +710,52 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
         else if(TailNum == TailNumber::Two) // 2 block gemms remaining
         {
+            if(get_block_id() == 0 && (get_thread_id() == 0 || get_thread_id() == 16))
+            {
+                int32_t a_ping = scale_a_tile_ping.get_thread_buffer()[0];
+                uint8_t* a_ping_bytes = reinterpret_cast<uint8_t*>(&a_ping);
+                int32_t b_ping = scale_b_tile_ping.get_thread_buffer()[0];
+                uint8_t* b_ping_bytes = reinterpret_cast<uint8_t*>(&b_ping);
+                int32_t a_pong = scale_a_tile_pong.get_thread_buffer()[0];
+                uint8_t* a_pong_bytes = reinterpret_cast<uint8_t*>(&a_pong);
+                int32_t b_pong = scale_b_tile_pong.get_thread_buffer()[0];
+                uint8_t* b_pong_bytes = reinterpret_cast<uint8_t*>(&b_pong);
+                printf("[tid=%d]: ScaleA ping: [%d,%d,%d,%d], ScaleB ping: [%d,%d,%d,%d]\n", get_thread_id(),
+                       a_ping_bytes[0], a_ping_bytes[1], a_ping_bytes[2], a_ping_bytes[3],
+                       b_ping_bytes[0], b_ping_bytes[1], b_ping_bytes[2], b_ping_bytes[3]);
+                printf("[tid=%d]: ScaleA pong: [%d,%d,%d,%d], ScaleB pong: [%d,%d,%d,%d]\n", get_thread_id(),
+                       a_pong_bytes[0], a_pong_bytes[1], a_pong_bytes[2], a_pong_bytes[3],
+                       b_pong_bytes[0], b_pong_bytes[1], b_pong_bytes[2], b_pong_bytes[3]);
+            }
             {
                 // read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
                 Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
                 Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
-
                 // C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
                 block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
-                // block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
-                // /// TODO: remove these after creating a block gemm with scales
-                // ignore = scale_a_tile_ping;
-                // ignore = scale_b_tile_ping;
             }
             {
                 // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
                 block_gemm(c_block_tile, a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
-                // block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
-                // /// TODO: remove these after creating a block gemm with scales
-                // ignore = scale_a_tile_pong;
-                // ignore = scale_b_tile_pong;
             }
         }
         else if(TailNum == TailNumber::One)
         {
+            if(get_block_id() == 0 && (get_thread_id() == 0 || get_thread_id() == 16))
+            {
+                int32_t a_ping = scale_a_tile_ping.get_thread_buffer()[0];
+                uint8_t* a_ping_bytes = reinterpret_cast<uint8_t*>(&a_ping);
+                int32_t b_ping = scale_b_tile_ping.get_thread_buffer()[0];
+                uint8_t* b_ping_bytes = reinterpret_cast<uint8_t*>(&b_ping);
+                printf("[tid=%d]: ScaleA ping: [%d,%d,%d,%d], ScaleB ping: [%d,%d,%d,%d]\n", get_thread_id(),
+                       a_ping_bytes[0], a_ping_bytes[1], a_ping_bytes[2], a_ping_bytes[3],
+                       b_ping_bytes[0], b_ping_bytes[1], b_ping_bytes[2], b_ping_bytes[3]);
+            }
             block_sync_lds();
             // C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
             block_gemm(c_block_tile, a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
-            // block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
-            // /// TODO: remove these after creating a block gemm with scales
-            // ignore = scale_a_tile_ping;
-            // ignore = scale_b_tile_ping;
             __builtin_amdgcn_sched_barrier(0);
         }
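Note: the device printfs above recover the four packed e8m0 bytes by pointer aliasing; on little-endian targets (AMD GPUs, x86 hosts) that agrees with plain byte shifts. A quick host-side check, using the {2, 4, 8, 16} byte pattern from init_method 3:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main()
{
    const int32_t packed = 0x10080402; // little-endian bytes {2, 4, 8, 16}
    uint8_t via_shift[4], via_alias[4];
    for(int k = 0; k < 4; ++k)
        via_shift[k] = static_cast<uint8_t>((packed >> (8 * k)) & 0xFF);
    std::memcpy(via_alias, &packed, 4); // well-defined alternative to the casts
    for(int k = 0; k < 4; ++k)
        assert(via_shift[k] == via_alias[k]);
    return 0;
}
```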
-        // Convert warp-level C tensors to block tile format
-        // auto c_block_tile = BlockGemm{}.MakeCBlockTile();
-        // using CWarpDstr = typename WarpGemm::CWarpDstr;
-        // constexpr auto c_warp_y_lengths =
-        //     to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
-        // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{};
-
-        // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-        //     static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-        //         c_block_tile.set_y_sliced_thread_data(
-        //             merge_sequences(sequence{}, c_warp_y_index_zeros),
-        //             merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
-        //             c_warp_tensors(mIter)(nIter).get_thread_buffer());
-        //     });
-        // });
-
         return c_block_tile;
     }
 };
diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
index e50dd388c7..4f3ecbb680 100644
--- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
+++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp
@@ -206,6 +206,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         using BlockWarps = typename BlockGemmShape::BlockWarps;
         using WarpTile   = typename BlockGemmShape::WarpTile;

+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
         constexpr index_t MWarp = BlockWarps::at(number<0>{});
         constexpr index_t NWarp = BlockWarps::at(number<1>{});
         constexpr index_t MPerXdl = WarpTile::at(number<0>{});
@@ -213,6 +214,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
         constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
         // constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
+        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);

         // Scale A: [MWarp * MPerXdl, ScaleKDimPerBlock] warp-level tile
         // For K=512: [16, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
@@ -220,12 +222,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         // Distribution: Replicate in M dimension, distribute in K dimension (no vectorization - scalar loads)
         return make_static_tile_distribution(
             tile_distribution_encoding,     // repeat over NWarps
-                tuple,                      // M dimension
-                sequence>,                  // K dimension
+                tuple,                      // M dimension
+                sequence>,                  // K dimension
                 tuple, sequence<2, 1>>,     // ,
-                tuple, sequence<1, 1>>,
-                sequence<2, 2>,             // ScaleKDimPerBlock, all int32 needed to cover KPerBlock
-                sequence<0, 2>>{});
+                tuple, sequence<1, 2>>,
+                sequence<1, 2>,             //
+                sequence<1, 0>>{});
     }

     template
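Note: the "64/16 = 4 threads in K dimension" comment implies the lane → (m, k_lane) mapping sketched below, assuming lanes are numbered linearly within the warp with M as the fast dimension. This is also why the device debug prints compare threads 0 and 16 (same row, adjacent K groups):

```cpp
#include <cassert>

int main()
{
    constexpr int warp_size = 64;
    constexpr int MPerXdl   = 16;
    constexpr int K_Lane    = warp_size / MPerXdl; // 4 lanes along packed K
    static_assert(K_Lane == 4, "matches the policy comment");

    for(int lane = 0; lane < warp_size; ++lane)
    {
        const int m      = lane % MPerXdl; // M row within the XDL tile
        const int k_lane = lane / MPerXdl; // which packed int32 along K
        if(lane == 0)  assert(m == 0 && k_lane == 0);
        if(lane == 16) assert(m == 0 && k_lane == 1); // row 0, next K group
    }
    return 0;
}
```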
@@ -235,6 +237,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         using BlockWarps = typename BlockGemmShape::BlockWarps;
         using WarpTile   = typename BlockGemmShape::WarpTile;

+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t MWarp = BlockWarps::at(number<0>{});
         constexpr index_t NWarp = BlockWarps::at(number<1>{});
         constexpr index_t NPerXdl = WarpTile::at(number<1>{});
@@ -242,7 +245,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         constexpr index_t ScaleKDimPerBlock = KPerBlock / BlockScaleSize / KXdlPack; // int32s per block
         constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
         // constexpr index_t KPackedElementsPerThread = ScaleKDimPerBlock / K_Lane; // 4/4 = 1 for K=512
-
+        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
         // Scale B: [NWarp * NPerXdl, ScaleKDimPerBlock] warp-level tile
         // Viewed as [N, K] = [64, 4] for K=512 (access pattern, not storage)
         // For K=512: [64, 4], distribute 4 int32s across 4 K_Lane threads (1 each)
@@ -250,12 +253,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
         // Distribution: Replicate in N dimension, distribute in K dimension
         return make_static_tile_distribution(
             tile_distribution_encoding,     // repeat over MWarps
-                tuple,                      // N dimension (first)
-                sequence>,                  // K dimension (second)
-                tuple, sequence<2, 1>>,     // which direction
-                tuple, sequence<1, 1>>,     // which index
-                sequence<2, 2>,             // replicate N
-                sequence<0, 2>>{});
+                tuple,                      // N dimension (first)
+                sequence>,                  // K dimension (second)
+                tuple, sequence<2, 1>>,     // ,
+                tuple, sequence<0, 2>>,
+                sequence<1, 2>,             //
+                sequence<1, 0>>{});
     }
 };
 } // namespace ck_tile