# Patch summary (free text ahead of the first "diff --git" line).
#
# Adds two overloads of WarpGemmImpl::operator() that take two extra
# arguments, a_scale and b_scale (both const int32_t&), and pass them on to
# WarpGemmAttribute together with the A/B (and, for the first overload, C)
# thread-buffer vectors:
#   * accumulate form : reads c's thread buffer, calls WarpGemmAttribute with
#                       (c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant),
#                       then writes c_vec back into c.
#   * returning form  : builds a fresh CWarpTensor from the value returned by
#                       WarpGemmAttribute(a_vec, a_scale, b_vec, b_scale).
#
# NOTE(review): the copy this was restored from had every line break and every
# angle-bracket span removed. Line breaks were re-derived from the +/context
# markers; the hunk sizes (33 and 29 added lines) are consistent with a
# six-line template header on the first overload and a one-line header on the
# second. The identifiers inside the angle brackets (opselA/opselB of type
# index_t, post_nop_, AWarpTensor/BWarpTensor, ADataType/BDataType/CDataType)
# and the indentation are reconstructed, not copied - confirm against the
# upstream header before applying.
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
index 38fd0d408b..145f635a2b 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
@@ -92,6 +92,39 @@ struct WarpGemmImpl
         c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
     }
 
+    template <index_t opselA,
+              index_t opselB,
+              typename CTensor,
+              typename ATensor,
+              typename BTensor,
+              bool post_nop_ = false>
+    CK_TILE_DEVICE void operator()(CTensor& c,
+                                   const ATensor& a,
+                                   const BTensor& b,
+                                   const int32_t& a_scale,
+                                   const int32_t& b_scale,
+                                   bool_constant<post_nop_> = {}) const
+    {
+        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
+        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
+        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
+        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
+
+        constexpr auto I0 = number<0>{};
+
+        const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
+        const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
+        auto c_vec       = c.get_thread_buffer().template get_as<CVec>()[I0];
+
+        // c_vec += a_vec * b_vec
+        WarpGemmAttribute{}.template operator()<opselA, opselB>(
+            c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
+
+        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
+    }
+
     template <typename ATensor, typename BTensor>
     CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
     {
@@ -116,6 +149,35 @@ struct WarpGemmImpl
 
         return c;
     }
+
+    template <index_t opselA, index_t opselB, typename ATensor, typename BTensor>
+    CK_TILE_DEVICE auto operator()(const ATensor& a,
+                                   const BTensor& b,
+                                   const int32_t& a_scale,
+                                   const int32_t& b_scale) const
+    {
+        using CTensor = CWarpTensor;
+        static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
+                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
+        CTensor c;
+
+        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
+        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
+        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
+
+        constexpr auto I0 = number<0>{};
+
+        const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
+        const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
+
+        // c_vec = a_vec * b_vec
+        auto c_vec =
+            WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
+
+        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
+
+        return c;
+    }
 };
 
 } // namespace ck_tile