diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index aa3e0493b1..052d77a470 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -844,14 +844,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B - scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + WG{}.template + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)); // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -973,14 +973,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B - scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + WG{}( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * NXdlPack + + inxdl)(kIter_pack * KXdlPack + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A + ikxdl * MXdlPack + imxdl, // A opsel + ikxdl * NXdlPack + inxdl); // B opsel // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1100,14 +1101,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B - scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + WG{}( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * NXdlPack + + inxdl)(kIter_pack * KXdlPack + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A + ikxdl * MXdlPack + imxdl, // A opsel + ikxdl * NXdlPack + inxdl); // B opsel // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1174,14 +1176,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B - scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + WG{}( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * NXdlPack + + inxdl)(kIter_pack * KXdlPack + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A + ikxdl * MXdlPack + imxdl, // A opsel + ikxdl * NXdlPack + inxdl); // B opsel // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( @@ -1242,14 +1245,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * NXdlPack + - inxdl)(kIter_pack * KXdlPack + ikxdl), - scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B - scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A - ikxdl * MXdlPack + imxdl, // A opsel - ikxdl * NXdlPack + inxdl); // B opsel + WG{}( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * NXdlPack + + inxdl)(kIter_pack * KXdlPack + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A + ikxdl * MXdlPack + imxdl, // A opsel + ikxdl * NXdlPack + inxdl); // B opsel // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index b72d07bc91..22c6927c70 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -40,39 +40,40 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t APackedSize = numeric_traits::PackedSize; constexpr index_t KPack = GetSmemPackA() * APackedSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - number{}, number{})), // xor on M - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - // constexpr int ContiguousThreadsCntInDS_READ_16B = 4; - // constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - // a_lds_block_desc_0, + // a_lds_block_desc, // make_tuple(make_xor_transform(make_tuple(number{}, - // number{})), + // number{})), // xor on M // make_pass_through_transform(number{})), // make_tuple(sequence<1, 0>{}, sequence<2>{}), // make_tuple(sequence<1, 0>{}, sequence<2>{})); - // constexpr auto a_lds_block_desc = transform_tensor_descriptor( - // a_lds_block_desc_permuted, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); + constexpr int ContiguousThreadsCntInDS_READ_16B = 4; - return a_lds_block_desc_permuted; + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // return a_lds_block_desc_permuted; + return a_lds_block_desc; } template