From ce26d9071e9a93e1fa2c9e87f81632cc4fcfe16f Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Tue, 16 Sep 2025 22:21:39 -0500 Subject: [PATCH] fix, function partially passed --- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 35 +++++++++++- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 53 +++++-------------- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 5 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 2 +- 4 files changed, 48 insertions(+), 47 deletions(-) diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 4aaae4cfe7..1c1e5999f8 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -107,8 +107,26 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); } + else if(init_method == 4) + { + ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + } + else if(init_method == 5) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! only support init_method 0/1/2/3/4"); + } -#if 1 +#if 0 #if 0 printf("printf a_host: \n"); for(int m = 0; m < M; m++) @@ -178,7 +196,7 @@ int run_mx_flatmm_with_layouts(int argc, for(int k = 0; k < K;) { printf("0x%08x ", *(reinterpret_cast(&a_host(m, k)))); - k += 2; + k += 8; } printf("\n"); } @@ -279,6 +297,19 @@ int run_mx_flatmm_with_layouts(int argc, c_dev_buf.FromDevice(c_rslt_host.data()); +#if 1 + printf("printf c_rslt_host: \n"); + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + printf("%.1f ", ck_tile::type_convert(c_rslt_host(m, n))); + } + printf("\n"); + } + printf("\n"); +#endif + bool pass = true; if(arg_parser.get_int("v") == 1) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 8eb09d0b5f..fb05e48917 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -38,11 +38,12 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem @@ -577,24 +578,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 a_warp_windows_pong; - // auto A_Lds_Stride = 8; - // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - // a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - // a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - - // auto weight_k_idx = kIter / number{}; - // auto weight_k_rank = kIter % number{}; - // move_tile_window( - // a_warp_windows_ping(mIter)(kIter), - // {mIter * MPerBlockPerIter, - // weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - // move_tile_window( - // a_warp_windows_pong(mIter)(kIter), - // {mIter * MPerBlockPerIter, - // weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK}); - // }); - // }); static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; @@ -603,14 +586,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; auto packed_m_rank = mIter % number{}; - move_tile_window( - a_warp_windows_ping(mIter)(kIter), - {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM, - kIter * KPerBlockPerIter}); - move_tile_window( - a_warp_windows_pong(mIter)(kIter), - {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM, - kIter * KPerBlockPerIter}); + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank, + kIter * KPerBlockPerIter}); + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank, + kIter * KPerBlockPerIter}); }); }); @@ -770,16 +751,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1( - &scale_a_tile_tensor_ping(I0)(I0).get_thread_buffer()[0]))); - } -#endif - // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; while(iCounter > 0) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index d5574fbd79..274f6d1107 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -155,9 +155,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? - tuple, // second >>>>>>>>>need to double confirm - // direction + sequence, + tuple, sequence>, // first direction // wave in blk, // thd in wave // // diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index b212d9aa4e..e70e04d2e9 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1538,7 +1538,7 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, - CVecType{0.f}, + c_vec, 4, 4, opselA,