updates, build pass

This commit is contained in:
mtgu0705
2025-09-15 03:03:02 -05:00
parent cc94eb6045
commit 9ceb3fd508
3 changed files with 68 additions and 62 deletions

View File

@@ -284,12 +284,14 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
{std::size_t(K), std::size_t(1)});
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
{std::size_t(1), std::size_t(K)});
for(int m = 0; m < M; ++m)
for(std::size_t m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
for(std::size_t k = 0; k < K; ++k)
{
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
@@ -297,7 +299,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
continue; // skip odd k
auto a_f4x2 = a_m_k(m, k);
auto a_scale = scale_a(m, k / ScaleBlockSize);
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
auto a_f4_lo =
@@ -311,9 +313,9 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
}
}
for(int n = 0; n < N; n++)
for(std::size_t n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
for(std::size_t k = 0; k < K; k++)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
@@ -321,7 +323,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
continue; // skip odd k
auto b_f4x2 = b_k_n(k, n);
auto b_scale = scale_b(k / ScaleBlockSize, n);
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
auto b_f4_lo =

View File

@@ -76,7 +76,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = FlatmmPipeline::BlockSize().x;
constexpr int block_size = MXFlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
@@ -86,7 +86,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
FlatmmPipeline,
MXFlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);

View File

@@ -118,8 +118,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t APackedSize = numeric_traits<ADataType>::packed_size;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::packed_size;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
@@ -629,25 +629,27 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
union UnionB
{
V4UInt_Buffer u = 0;
MXFP4_Buffer mxfp4;
} ub;
// using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
// // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
// using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
// union UnionB
// {
// V4UInt_Buffer u = 0;
// MXFP4_Buffer mxfp4;
// } ub;
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
// pingpong buffer for Scale A and Scale B
@@ -708,8 +710,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
@@ -785,8 +786,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -850,8 +850,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack),
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack));
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -914,8 +916,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -973,15 +974,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1047,8 +1049,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -1101,15 +1102,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1176,15 +1178,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1245,15 +1248,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxdl * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(