diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp index ca116ef17e..36de380719 100644 --- a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -89,6 +89,16 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 } } + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + template __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) { diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 5b6a94a0b7..4f531b55ae 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -150,7 +150,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr index_t GemmMPerWave = 64; constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 1; constexpr index_t NRepeat = 1; @@ -164,7 +164,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;