mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix: the reference function now passes; next, check the kernel function.
This commit is contained in:
@@ -307,7 +307,7 @@ auto create_args(int argc, char* argv[])
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("k", "512", "k dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
|
||||
@@ -100,6 +100,13 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
@@ -108,7 +115,11 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
printf("%.2f ", ck_tile::type_convert<float>(a_host(m, k)));
|
||||
auto a_f4x2 = a_host(m, k);
|
||||
if(k % 2 == 0)
|
||||
printf("%.2f ", ck_tile::type_convert<float>(a_f4x2.unpack(ck_tile::number<1>{})));
|
||||
else
|
||||
printf("%.2f ", ck_tile::type_convert<float>(a_f4x2.unpack(ck_tile::number<0>{})));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -140,6 +151,17 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
#endif
|
||||
printf("\n");
|
||||
|
||||
printf("printf scale_a: \n");
|
||||
for(int m = 0; m < M / ScaleGranularityM; m++)
|
||||
{
|
||||
for(int k = 0; k < K / ScaleGranularityK; k++)
|
||||
{
|
||||
printf("%.2f ", ck_tile::type_convert<float>(scale_a(m, k)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("printf scale_b: \n");
|
||||
for(int n = 0; n < N / ScaleGranularityN; n++)
|
||||
{
|
||||
|
||||
@@ -657,7 +657,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_DramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
|
||||
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
@@ -770,6 +770,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
#if 1
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
printf("tid: %u, scale_a_tile_tensor_ping(0)(0)[0]: 0x%08x\n",
|
||||
threadIdx.x,
|
||||
*(reinterpret_cast<uint32_t*>(
|
||||
&scale_a_tile_tensor_ping(I0)(I0).get_thread_buffer()[0])));
|
||||
}
|
||||
#endif
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
|
||||
@@ -286,14 +286,16 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
|
||||
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t MWavePerBlk = M_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
tile_distribution_encoding<sequence<N_Wrap>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
@@ -311,14 +313,16 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
|
||||
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
tile_distribution_encoding<sequence<M_Wrap>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
|
||||
Reference in New Issue
Block a user