mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
add interface in warp_gemm_impl
This commit is contained in:
@@ -92,6 +92,39 @@ struct WarpGemmImpl
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
index_t opselA,
|
||||
index_t opselB,
|
||||
bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(
|
||||
c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
|
||||
{
|
||||
@@ -116,6 +149,35 @@ struct WarpGemmImpl
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor, index_t opselA, index_t opselB>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
CTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
auto c_vec =
|
||||
WarpGemmAttribute{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user