This commit is contained in:
mtgu0705
2025-09-15 00:05:18 -05:00
parent 22586c3135
commit cc94eb6045
2 changed files with 66 additions and 61 deletions

View File

@@ -844,14 +844,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack),
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -973,14 +973,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
WG{}(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1100,14 +1101,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
WG{}(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1174,14 +1176,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
WG{}(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1242,14 +1245,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
WG{}(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(

View File

@@ -40,39 +40,40 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
number<MPerBlock>{}, number<KPerBlock / KPack>{})), // xor on M
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
// constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
// a_lds_block_desc_0,
// a_lds_block_desc,
// make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
// number<ContiguousThreadsCntInDS_READ_16B>{})),
// number<KPerBlock / KPack>{})), // xor on M
// make_pass_through_transform(number<KPack>{})),
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
// a_lds_block_desc_permuted,
// make_tuple(make_pass_through_transform(number<MPerBlock>{}),
// make_merge_transform_v3_division_mod(
// make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
return a_lds_block_desc_permuted;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
template <typename Problem>