Improve the general performance of the Preshuffled GEMM V3 & delete the unnecessary instances (#2166)

* make the work compiled

* Solved the example code, but still have the profiler error

* Finished the feature

* Clang format and update the CHANGELOG

* solve the preshuffle v1 & v2 problem

* Comment Addressed

* Comment Addressed
This commit is contained in:
Thomas Ning
2025-05-12 09:52:58 -07:00
committed by GitHub
parent 9d1e44e56a
commit b49f7de81f
11 changed files with 445 additions and 918 deletions

View File

@@ -167,11 +167,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static constexpr index_t KPackPerGroup = KPack / KGroup;
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static constexpr auto MakeDsGridPointer()
{
@@ -209,7 +211,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K)
{
return math::integer_divide_ceil(K, KLane * KPack);
return math::integer_divide_ceil(K, KLane * KPackPerGroup);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
@@ -351,7 +353,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
constexpr index_t NkSwizzleNumber = Number<warpSize * KPackPerGroup>{};
return make_naive_tensor_descriptor(
make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
@@ -1228,7 +1230,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds
@@ -1668,7 +1670,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack * (get_thread_local_1d_id() % warpSize)));
KPackPerGroup * (get_thread_local_1d_id() % warpSize)));
// LDS allocation for A and B: be careful of alignment
// Cast after lds