mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
updates
This commit is contained in:
@@ -844,14 +844,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack),
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -973,14 +973,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
WG{}(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1100,14 +1101,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
WG{}(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1174,14 +1176,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
WG{}(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1242,14 +1245,15 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
WG{}(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
|
||||
ikxdl * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
|
||||
@@ -40,39 +40,40 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
number<MPerBlock>{}, number<KPerBlock / KPack>{})), // xor on M
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_0,
|
||||
// a_lds_block_desc,
|
||||
// make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
// number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
// number<KPerBlock / KPack>{})), // xor on M
|
||||
// make_pass_through_transform(number<KPack>{})),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_permuted,
|
||||
// make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
// make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
return a_lds_block_desc_permuted;
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return a_lds_block_desc_permuted;
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user