From 0572b85298d28751dbedc43f4f0e9fd882b316bd Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 2 Dec 2021 23:37:57 +0000 Subject: [PATCH] renaming/comments [ROCm/composable_kernel commit: d7a0a3f94cee332fcbe181a9174491028d2620a9] --- .../include/tensor_operation/blockwise_gemm_xdlops.hpp | 5 +++-- composable_kernel/include/tensor_operation/xdlops_gemm.hpp | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 1c9337db15..4a0253df46 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -247,14 +247,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 } private: - // A[K, M] + // A[K0, M0, M1, M2, K1] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - // B[K, N] + // B[K0, N0, N1, N2, K1] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); + // C[M, N] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index e07fa58076..0f4d9f243d 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -545,7 +545,7 @@ struct MfmaSelector selected_mfma.k_per_blk; } - static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } + static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } }; template @@ -708,7 +708,7 @@ struct XdlopsGemm static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()