fix, function partially passed

This commit is contained in:
mtgu0705
2025-09-16 22:21:39 -05:00
parent 0a89ed13a5
commit ce26d9071e
4 changed files with 48 additions and 47 deletions

View File

@@ -107,8 +107,26 @@ int run_mx_flatmm_with_layouts(int argc,
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 4)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
}
else if(init_method == 5)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! only support init_method 0/1/2/3/4");
}
#if 1
#if 0
#if 0
printf("printf a_host: \n");
for(int m = 0; m < M; m++)
@@ -178,7 +196,7 @@ int run_mx_flatmm_with_layouts(int argc,
for(int k = 0; k < K;)
{
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&a_host(m, k))));
k += 2;
k += 8;
}
printf("\n");
}
@@ -279,6 +297,19 @@ int run_mx_flatmm_with_layouts(int argc,
c_dev_buf.FromDevice(c_rslt_host.data());
#if 1
printf("printf c_rslt_host: \n");
for(int m = 0; m < M; m++)
{
for(int n = 0; n < N; n++)
{
printf("%.1f ", ck_tile::type_convert<float>(c_rslt_host(m, n)));
}
printf("\n");
}
printf("\n");
#endif
bool pass = true;
if(arg_parser.get_int("v") == 1)
{

View File

@@ -38,11 +38,12 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr int ScaleGranularityK = 32;
// static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
static constexpr int MXdlPack = 2; // it's fixed for fp4
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int KXdlPack = 2;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
static constexpr int MXdlPack = 2; // it's fixed for fp4
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int KXdlPack = 2;
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
};
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
@@ -577,24 +578,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
MIterPerWarp>
a_warp_windows_pong;
// auto A_Lds_Stride = 8;
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
// a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
// move_tile_window(
// a_warp_windows_ping(mIter)(kIter),
// {mIter * MPerBlockPerIter,
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
// move_tile_window(
// a_warp_windows_pong(mIter)(kIter),
// {mIter * MPerBlockPerIter,
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
// });
// });
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
@@ -603,14 +586,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
auto packed_m_idx = mIter / number<MXdlPack>{};
auto packed_m_rank = mIter % number<MXdlPack>{};
move_tile_window(
a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
move_tile_window(
a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
kIter * KPerBlockPerIter});
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
kIter * KPerBlockPerIter});
});
});
@@ -770,16 +751,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
__builtin_amdgcn_sched_barrier(0);
#if 1
if(blockIdx.x == 0)
{
printf("tid: %u, scale_a_tile_tensor_ping(0)(0)[0]: 0x%08x\n",
threadIdx.x,
*(reinterpret_cast<uint32_t*>(
&scale_a_tile_tensor_ping(I0)(I0).get_thread_buffer()[0])));
}
#endif
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)

View File

@@ -155,9 +155,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NWavePerBlk, NXdlPack>, // second >>>>>>>>>need to double confirm
// direction
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>,
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>

View File

@@ -1538,7 +1538,7 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
CVecType{0.f},
c_vec,
4,
4,
opselA,