mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
optimized the VGPR repack issue for MXFP4
This commit is contained in:
@@ -612,27 +612,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
// using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
|
||||
// union UnionB
|
||||
// {
|
||||
// V4UInt_Buffer u = 0;
|
||||
// MXFP4_Buffer mxfp4;
|
||||
// } ub;
|
||||
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// Use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
|
||||
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
|
||||
// Scratch union exposing two views of the same storage for a B tile:
//   - `mxfp4`: the tile type returned by load_tile(b_flat_dram_window)
//   - `u`    : a thread_buffer of four uint32_t
// In this function a load is stored through `mxfp4` and then copied out through `u`
// into the ping/pong arrays; before the warp GEMM the copy goes the other way
// (assign `u`, pass `mxfp4`).
// NOTE(review): reading the member that was not last written relies on the compiler
// supporting union-based type punning, and presumably both members occupy the same
// number of bytes — TODO confirm (e.g. with a static_assert on sizeof).
union UnionBuf
|
||||
{
|
||||
// Raw four-word view; carries the default member initializer for the union.
V4UInt_B_Buffer u = 0;
|
||||
// Typed view handed to / received from the tile API.
MXFP4_B_Buffer mxfp4;
|
||||
} ub;
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
@@ -693,7 +691,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -742,16 +741,33 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
using MXFP4_A_Buffer_ping =
|
||||
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
|
||||
// Use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
|
||||
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
|
||||
// Scratch union exposing two views of the same storage for an A tile read from the
// ping windows:
//   - `mxfp4`: the tile type returned by load_tile(a_warp_windows_ping(0)(0))
//   - `u`    : a thread_buffer of four uint32_t
// Loads are stored through `mxfp4` and copied out through `u` into a_warp_tensor.
// NOTE(review): reading the member that was not last written relies on the compiler
// supporting union-based type punning, and presumably both members occupy the same
// number of bytes — TODO confirm (e.g. with a static_assert on sizeof).
union UnionBuf_A_ping
|
||||
{
|
||||
// Raw four-word view; carries the default member initializer for the union.
V4UInt_A_Buffer u = 0;
|
||||
// Typed view handed to / received from the tile API.
MXFP4_A_Buffer_ping mxfp4;
|
||||
} ua_ping;
|
||||
|
||||
using MXFP4_A_Buffer_pong =
|
||||
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
|
||||
// Scratch union exposing two views of the same storage for an A tile read from the
// pong windows:
//   - `mxfp4`: the tile type returned by load_tile(a_warp_windows_pong(0)(0))
//   - `u`    : a thread_buffer of four uint32_t
// Loads are stored through `mxfp4` and copied out through `u` into a_warp_tensor.
// NOTE(review): reading the member that was not last written relies on the compiler
// supporting union-based type punning, and presumably both members occupy the same
// number of bytes — TODO confirm (e.g. with a static_assert on sizeof).
union UnionBuf_A_pong
|
||||
{
|
||||
// Raw four-word view; carries the default member initializer for the union.
V4UInt_A_Buffer u = 0;
|
||||
// Typed view handed to / received from the tile API.
MXFP4_A_Buffer_pong mxfp4;
|
||||
} ua_pong;
|
||||
|
||||
// preload A00,A10... from lds
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor;
|
||||
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u;
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -771,7 +787,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -828,13 +845,19 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -849,7 +872,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
@@ -858,8 +882,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -882,8 +907,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -901,7 +926,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -958,13 +984,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -979,7 +1012,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
@@ -988,8 +1022,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -1012,8 +1047,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -1036,7 +1071,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1088,13 +1124,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1109,7 +1152,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
@@ -1118,8 +1162,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -1137,8 +1182,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
@@ -1164,13 +1209,19 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -1185,7 +1236,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
@@ -1194,8 +1246,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -1234,13 +1287,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1255,7 +1315,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
@@ -1264,8 +1325,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
|
||||
Reference in New Issue
Block a user