fix, reference function passed, next check kernel function

This commit is contained in:
mtgu0705
2025-09-16 03:01:12 -05:00
parent ec9bcef591
commit 0a89ed13a5
4 changed files with 49 additions and 13 deletions

View File

@@ -307,7 +307,7 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "512", "k dimension")
.insert("k", "256", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")

View File

@@ -100,6 +100,13 @@ int run_mx_flatmm_with_layouts(int argc,
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 3)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
#if 1
#if 0
@@ -108,7 +115,11 @@ int run_mx_flatmm_with_layouts(int argc,
{
for(int k = 0; k < K; k++)
{
printf("%.2f ", ck_tile::type_convert<float>(a_host(m, k)));
auto a_f4x2 = a_host(m, k);
if(k % 2 == 0)
printf("%.2f ", ck_tile::type_convert<float>(a_f4x2.unpack(ck_tile::number<1>{})));
else
printf("%.2f ", ck_tile::type_convert<float>(a_f4x2.unpack(ck_tile::number<0>{})));
}
printf("\n");
}
@@ -140,6 +151,17 @@ int run_mx_flatmm_with_layouts(int argc,
#endif
printf("\n");
printf("printf scale_a: \n");
for(int m = 0; m < M / ScaleGranularityM; m++)
{
for(int k = 0; k < K / ScaleGranularityK; k++)
{
printf("%.2f ", ck_tile::type_convert<float>(scale_a(m, k)));
}
printf("\n");
}
printf("\n");
printf("printf scale_b: \n");
for(int n = 0; n < N / ScaleGranularityN; n++)
{

View File

@@ -657,7 +657,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
scale_a_window.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ScaleA_DramTileDistribution<Problem>());
PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
@@ -770,6 +770,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
__builtin_amdgcn_sched_barrier(0);
#if 1
if(blockIdx.x == 0)
{
printf("tid: %u, scale_a_tile_tensor_ping(0)(0)[0]: 0x%08x\n",
threadIdx.x,
*(reinterpret_cast<uint32_t*>(
&scale_a_tile_tensor_ping(I0)(I0).get_thread_buffer()[0])));
}
#endif
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)

View File

@@ -286,14 +286,16 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
constexpr index_t MWavePerBlk = M_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1>, sequence<2, 1>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
tile_distribution_encoding<sequence<N_Wrap>, // ?
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
@@ -311,14 +313,16 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // ?
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1>, sequence<2, 1>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
tile_distribution_encoding<sequence<M_Wrap>, // ?
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});