mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
@@ -247,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K, M]
|
||||
// A[K0, M0, M1, M2, K1]
|
||||
// Compile-time descriptor for the A operand, built as a packed naive tensor
// descriptor with five lengths: (K0, I1, I1, I1, K1).
// NOTE(review): I1 is presumably the compile-time constant 1, which would make
// the three middle dimensions singletons — I1 is defined outside this excerpt; confirm.
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
// B[K0, N0, N1, N2, K1]
|
||||
// Compile-time descriptor for the B operand, built as a packed naive tensor
// descriptor with five lengths: (K0, I1, I1, I1, K1) — the same lengths tuple
// that a_thread_desc_ is built from.
// NOTE(review): I1 is presumably the compile-time constant 1 — it is defined
// outside this excerpt; confirm.
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<K0>{}, I1, I1, I1, Number<K1>{}));
|
||||
|
||||
// C[M, N]
|
||||
// Compile-time descriptor for the C operand: a packed naive tensor descriptor
// with two lengths, (MRepeat, NRepeat).
// NOTE(review): MRepeat and NRepeat are declared outside this excerpt; their
// meaning (presumably per-thread repeat counts along M and N) should be confirmed there.
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
|
||||
@@ -545,7 +545,7 @@ struct MfmaSelector
|
||||
selected_mfma.k_per_blk;
|
||||
}
|
||||
|
||||
// Returns selected_mfma.k_per_blk.
// NOTE(review): GetK1PerXdlops() in this same excerpt returns the identical field;
// this looks like the pre-rename (removed) side of the diff hunk — confirm which
// accessor exists in the current tree before calling it.
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; }
|
||||
// Returns selected_mfma.k_per_blk, i.e. the value XdlopsGemm stores as K1PerXdlops.
// NOTE(review): GetKPerThread() in this same excerpt returns the identical field;
// this looks like the post-rename (added) side of the diff hunk — confirm.
static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
|
||||
};
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
|
||||
@@ -708,7 +708,7 @@ struct XdlopsGemm
|
||||
// Compile-time copy of mfma.selected_mfma (`mfma` is declared outside this excerpt).
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
// Value reported by mfma.GetKPerXdlops().
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
// NOTE(review): K1PerXdlops is initialised twice below, once from GetKPerThread()
// and once from GetK1PerXdlops(). This appears to be the removed/added pair of a
// diff hunk (accessor rename), not two live declarations — confirm which line is
// current; both accessors shown in this excerpt return selected_mfma.k_per_blk.
static constexpr auto K1PerXdlops = mfma.GetKPerThread();
|
||||
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
|
||||
// Quotient of the two constants above: KPerXdlops / K1PerXdlops.
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
|
||||
|
||||
Reference in New Issue
Block a user