add interface in warp_gemm_impl

This commit is contained in:
Gino Lu
2025-09-12 02:53:50 -05:00
committed by mtgu0705
parent c6135f6abe
commit 8b1dd60c08

View File

@@ -92,6 +92,39 @@ struct WarpGemmImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename CTensor,
typename ATensor,
typename BTensor,
index_t opselA,
index_t opselB,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale,
bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}.template operator()<opselA, opselB>(
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
@@ -116,6 +149,35 @@ struct WarpGemmImpl
return c;
}
template <typename ATensor, typename BTensor, index_t opselA, index_t opselB>
CK_TILE_DEVICE auto operator()(const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
// c_vec = a_vec * b_vec
auto c_vec =
WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
return c;
}
};
} // namespace ck_tile