diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 4181f4cba7..61d01a9596 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -518,6 +518,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // main body index_t k0_block_data_begin = 0; + c_thread_buf.Clear(); /* Added by this patch: calls Clear() on c_thread_buf once, ahead of the HasMainKBlockLoop branch that follows. Clear() is the member introduced in the second file of this patch. */ + if constexpr(HasMainKBlockLoop) { do diff --git a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp index ed3ae201fc..6924f20b7c 100644 --- a/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp +++ b/composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp @@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray static constexpr index_t vector_size = GetVectorSize(); + __host__ __device__ static constexpr index_t GetNumVectors() { return N; } /* Added: returns N. N is declared outside this hunk; presumably the number of vectors the buffer holds; confirm against the struct's template parameters. */ + + __host__ __device__ static constexpr index_t GetNumElements() + { + return GetVectorSize() * GetNumVectors(); + } /* Added: GetNumElements() is the product GetVectorSize() * GetNumVectors(). */ + VecBaseType invalid_element_value_ = VecBaseType{0}; T invalid_vec_value_ = T{0}; @@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray return GetElement(i, true); } + __host__ __device__ void Clear() + { + static_for<0, GetNumElements(), 1>{}( + [&](auto i) { GetElement(i, true) = invalid_element_value_; }); + } /* Added: Clear() assigns invalid_element_value_ to GetElement(i, true) for each index that static_for<0, GetNumElements(), 1> supplies (presumably i = 0 .. GetNumElements() - 1 in steps of 1; static_for is defined elsewhere). NOTE(review): invalid_element_value_ is a non-const member whose in-class default is VecBaseType{0}, so the buffer ends up zeroed only while that member still holds its default; confirm no constructor or caller changes it before Clear() runs. */ + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }