mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
updated, function passed.
This commit is contained in:
@@ -294,7 +294,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run A16W4_Flatmm kernel "
|
||||
std::cout << "Run MXFP4_Flatmm kernel "
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
@@ -111,15 +111,15 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -297,7 +297,7 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
printf("printf c_rslt_host: \n");
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
@@ -329,6 +329,18 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
#if 0
|
||||
printf("printf c_m_n_host_ref: \n");
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
printf("%.1f ", ck_tile::type_convert<float>(c_m_n_host_ref(m, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
@@ -586,12 +586,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
auto packed_m_idx = mIter / number<MXdlPack>{};
|
||||
auto packed_m_rank = mIter % number<MXdlPack>{};
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -708,6 +710,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
// move Scale A window to next K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
@@ -720,6 +723,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
// move Scale B window to next K
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// A_Lds_TileDist may differ from ADramTileDistribution
|
||||
@@ -767,7 +771,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1012,6 +1016,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// TAIL
|
||||
@@ -1030,7 +1036,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -46,21 +46,11 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
// a_lds_block_desc,
|
||||
// make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
// number<KPerBlock / KPack>{})), // xor on M
|
||||
// make_pass_through_transform(number<KPack>{})),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
@@ -118,21 +108,22 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
|
||||
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWaves>,
|
||||
tuple<sequence<M0, MXdlPack>, sequence<K0, K1>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -167,38 +158,6 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
// {
|
||||
// using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t WaveSize = get_warp_size();
|
||||
// constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
// constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
// constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
// constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
// constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<>, // ?
|
||||
// tuple<sequence<NWavePerBlk>, // second direction
|
||||
// sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// // direction
|
||||
// // wave in blk, // thd in wave
|
||||
// // <M, K> // <M, K>
|
||||
// tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
// tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// // <repeat, vec_load>
|
||||
// sequence<2>,
|
||||
// sequence<2>>{});
|
||||
// }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user