Updated; functional test passed.

This commit is contained in:
mtgu0705
2025-09-17 03:58:00 -05:00
parent ce26d9071e
commit 80c1743034
4 changed files with 47 additions and 70 deletions

View File

@@ -294,7 +294,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run A16W4_Flatmm kernel "
std::cout << "Run MXFP4_Flatmm kernel "
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

View File

@@ -111,15 +111,15 @@ int run_mx_flatmm_with_layouts(int argc,
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 5)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
else
{
@@ -297,7 +297,7 @@ int run_mx_flatmm_with_layouts(int argc,
c_dev_buf.FromDevice(c_rslt_host.data());
#if 1
#if 0
printf("printf c_rslt_host: \n");
for(int m = 0; m < M; m++)
{
@@ -329,6 +329,18 @@ int run_mx_flatmm_with_layouts(int argc,
pass = ck_tile::check_err(
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
#if 0
printf("printf c_m_n_host_ref: \n");
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
printf("%.1f ", ck_tile::type_convert<float>(c_m_n_host_ref(m, n)));
}
printf("\n");
}
#endif
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;

View File

@@ -586,12 +586,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
auto packed_m_idx = mIter / number<MXdlPack>{};
auto packed_m_rank = mIter % number<MXdlPack>{};
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
kIter * KPerBlockPerIter});
move_tile_window(
a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
move_tile_window(
a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
});
});
@@ -708,6 +710,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
// move Scale A window to next K
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
@@ -720,6 +723,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
});
});
// move Scale B window to next K
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// A_Lds_TileDist may differ from ADramTileDistribution
@@ -767,7 +771,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -1012,6 +1016,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
iCounter--;
}
// TAIL
@@ -1030,7 +1036,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});

View File

@@ -46,21 +46,11 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
number<KPack>{},
number<1>{});
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
// a_lds_block_desc,
// make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
// number<KPerBlock / KPack>{})), // xor on M
// make_pass_through_transform(number<KPack>{})),
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
@@ -118,21 +108,22 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
constexpr int M0 = TileShape::WarpTile::at(I0);
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0);
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 8
constexpr int K0 = K_Lane; // 4
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWaves>,
tuple<sequence<M0, MXdlPack>, sequence<K0, K1>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<1>>{});
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
@@ -167,38 +158,6 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<2>>{});
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
// {
// using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t WaveSize = get_warp_size();
// constexpr index_t WaveNum = BlockSize / WaveSize;
// constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
// constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
// constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
// constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
// constexpr index_t NWavePerBlk = N_Warp;
// return make_static_tile_distribution(
// tile_distribution_encoding<
// sequence<>, // ?
// tuple<sequence<NWavePerBlk>, // second direction
// sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
// // direction
// // wave in blk, // thd in wave
// // <M, K> // <M, K>
// tuple<sequence<1>, sequence<2, 2>>, // which direction
// tuple<sequence<0>, sequence<0, 1>>, // which index
// // <repeat, vec_load>
// sequence<2>,
// sequence<2>>{});
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
{