mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
fix, function partially passed
This commit is contained in:
@@ -107,8 +107,26 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 5)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! only support init_method 0/1/2/3/4");
|
||||
}
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
#if 0
|
||||
printf("printf a_host: \n");
|
||||
for(int m = 0; m < M; m++)
|
||||
@@ -178,7 +196,7 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
for(int k = 0; k < K;)
|
||||
{
|
||||
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&a_host(m, k))));
|
||||
k += 2;
|
||||
k += 8;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
@@ -279,6 +297,19 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
#if 1
|
||||
printf("printf c_rslt_host: \n");
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
printf("%.1f ", ck_tile::type_convert<float>(c_rslt_host(m, n)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
#endif
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
|
||||
@@ -38,11 +38,12 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
|
||||
static constexpr int ScaleGranularityK = 32;
|
||||
|
||||
// static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 2;
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int MXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
|
||||
@@ -577,24 +578,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
// auto A_Lds_Stride = 8;
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
// a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
// move_tile_window(
|
||||
// a_warp_windows_ping(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// move_tile_window(
|
||||
// a_warp_windows_pong(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// });
|
||||
// });
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
@@ -603,14 +586,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
auto packed_m_idx = mIter / number<MXdlPack>{};
|
||||
auto packed_m_rank = mIter % number<MXdlPack>{};
|
||||
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank,
|
||||
kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -770,16 +751,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
#if 1
|
||||
if(blockIdx.x == 0)
|
||||
{
|
||||
printf("tid: %u, scale_a_tile_tensor_ping(0)(0)[0]: 0x%08x\n",
|
||||
threadIdx.x,
|
||||
*(reinterpret_cast<uint32_t*>(
|
||||
&scale_a_tile_tensor_ping(I0)(I0).get_thread_buffer()[0])));
|
||||
}
|
||||
#endif
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
|
||||
@@ -155,9 +155,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // second >>>>>>>>>need to double confirm
|
||||
// direction
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>,
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
|
||||
@@ -1538,7 +1538,7 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
CVecType{0.f},
|
||||
c_vec,
|
||||
4,
|
||||
4,
|
||||
opselA,
|
||||
|
||||
Reference in New Issue
Block a user