diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 4dc3303c39..1c9337db15 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); StaticBufferOfVectorTypeV2, + vector_type, MRepeat * NRepeat, true> c_thread_buf_; @@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); } __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 7534215c04..4181f4cba7 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // output: register to global memory { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); - // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index const auto c_thread_mtx_on_block = diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 68b4db1a43..e07fa58076 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -507,7 +507,7 @@ struct MfmaSelector static constexpr auto selected_mfma = mfma_type()>{}; - __host__ __device__ static constexpr void mfma_check() + __host__ __device__ constexpr MfmaSelector() { static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == selected_mfma.num_regs_per_blk, @@ -533,8 +533,6 @@ struct MfmaSelector "is_k_reduction wrong!"); } - __host__ __device__ constexpr MfmaSelector() { mfma_check(); } - static constexpr bool IsABroadcast() { static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); @@ -621,6 +619,8 @@ struct XdlopsGemm return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } + __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index e0da35c7ba..070350fc0d 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -18,7 +18,7 @@ #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_DYNAMIC_MODE 0 +#define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0