This commit is contained in:
lalala-sh
2025-05-06 07:36:59 +00:00
parent 9c06c3817a
commit 0ab978584d
2 changed files with 13 additions and 12 deletions

View File

@@ -535,7 +535,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
@@ -631,7 +631,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -657,7 +657,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -683,7 +683,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
@@ -700,8 +700,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
});
});
}
HotLoopScheduler();
});
HotLoopScheduler();
};
LoopFunc(I0, I1);
@@ -771,7 +771,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -791,7 +791,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -811,7 +811,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
@@ -868,7 +868,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
@@ -930,7 +930,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineSch
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(

View File

@@ -190,6 +190,7 @@ struct GridwiseMoeGemm
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
// static_assert(KGroup == 2, "");
static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
@@ -1349,7 +1350,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -2064,7 +2065,7 @@ struct GridwiseMoeGemm
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,