From 9ceb3fd508338f0416f8937c4319729bb751f37b Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Mon, 15 Sep 2025 03:03:02 -0500 Subject: [PATCH] updates, build pass --- .../ck_tile/host/reference/reference_gemm.hpp | 18 +-- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 4 +- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 108 +++++++++--------- 3 files changed, 68 insertions(+), 62 deletions(-) diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 2c56f7acb9..37eb0cee82 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -284,12 +284,14 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, const std::size_t ScaleBlockSize = K / scale_a.get_length(1); - HostTensor a_m_k_scaled({M, K}, {K, 1}); - HostTensor b_k_n_scaled({K, N}, {1, N}); + HostTensor a_m_k_scaled({std::size_t(M), std::size_t(K)}, + {std::size_t(K), std::size_t(1)}); + HostTensor b_k_n_scaled({std::size_t(K), std::size_t(N)}, + {std::size_t(1), std::size_t(K)}); - for(int m = 0; m < M; ++m) + for(std::size_t m = 0; m < M; ++m) { - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { if constexpr(std::is_same_v) { @@ -297,7 +299,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, continue; // skip odd k auto a_f4x2 = a_m_k(m, k); - auto a_scale = scale_a(m, k / ScaleBlockSize); + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); // auto f4_lo = ck_tile::type_convert(f4x2)[0]; // auto f4_hi = ck_tile::type_convert(f4x2)[1]; auto a_f4_lo = @@ -311,9 +313,9 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, } } - for(int n = 0; n < N; n++) + for(std::size_t n = 0; n < N; n++) { - for(int k = 0; k < K; k++) + for(std::size_t k = 0; k < K; k++) { if constexpr(std::is_same_v) { @@ -321,7 +323,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, continue; // skip odd k auto b_f4x2 = b_k_n(k, n); - auto b_scale = scale_b(k / ScaleBlockSize, n); + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); // auto f4_lo = ck_tile::type_convert(f4x2)[0]; // auto f4_hi = ck_tile::type_convert(f4x2)[1]; auto b_f4_lo = diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index b2a5f39793..763a2a9d6e 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -76,7 +76,7 @@ struct MXFlatmmKernel : FlatmmKernel( kentry2>), block_size, dync_smem_size); diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 052d77a470..c0b9710211 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,8 +118,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1::packed_size; - static constexpr index_t BPackedSize = numeric_traits::packed_size; + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; static constexpr index_t MXdlPack = Problem::MXdlPack; static constexpr index_t NXdlPack = Problem::NXdlPack; @@ -629,25 +629,27 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1; - union UnionB - { - V4UInt_Buffer u = 0; - MXFP4_Buffer mxfp4; - } ub; + // using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window)); + // // use v4i32 as the data type between basicblock to avoid unpack and repack operation. + // using V4UInt_Buffer = thread_buffer; + // union UnionB + // { + // V4UInt_Buffer u = 0; + // MXFP4_Buffer mxfp4; + // } ub; // pingpong buffer for B statically_indexed_array< statically_indexed_array, NIterPerWarp> b_flat_dram_windows; - statically_indexed_array, - NIterPerWarp> + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> b_warp_tensor_ping; - statically_indexed_array, - NIterPerWarp> + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> b_warp_tensor_pong; // pingpong buffer for Scale A and Scale B @@ -708,8 +710,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}), b_warp_tensor_ping(nIter_pack * number{} + inxdl)( kIter_pack * number{} + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)); + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -914,8 +916,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}( + WG{}.template + operator()( c_warp_tensor, a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1047,8 +1049,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}( + WG{}.template + operator()( c_warp_tensor, a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1176,15 +1178,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}( + WG{}.template + operator()( c_warp_tensor, a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1245,15 +1248,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}( + WG{}.template + operator()( c_warp_tensor, a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data(