mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper
* Add basic gemm example
* Update docs
* Add tutorial for gemm using ck wrapper
* Add perf note
* edits
* Fix cmake
* Fixes

---------

Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
This commit is contained in:
@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
|
||||
decltype(dim_access_order),
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
Sequence<false>,
|
||||
Sequence<false>>{in_grid_desc,
|
||||
make_tuple(src_tensor.GetMultiIdxOffsets()),
|
||||
out_grid_desc,
|
||||
make_tuple(dst_tensor.GetMultiIdxOffsets()),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
Sequence<true>,
|
||||
Sequence<true>>{in_grid_desc,
|
||||
make_tuple(src_tensor.GetMultiIdxOffsets()),
|
||||
out_grid_desc,
|
||||
make_tuple(dst_tensor.GetMultiIdxOffsets()),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
transfer.Run(tie(in_grid_desc),
|
||||
tie(src_tensor.GetBuffer()),
|
||||
@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
|
||||
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
|
||||
{
|
||||
// Perform copy from DynamicBuffer to StaticBuffer
|
||||
const auto src_dst_slice_origin =
|
||||
const auto dst_slice_origin_idxs =
|
||||
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
|
||||
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
|
||||
[&](auto I) {
|
||||
if constexpr(I == VectorDim)
|
||||
{
|
||||
return Number<ScalarPerVector>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
},
|
||||
Number<num_dims>{});
|
||||
|
||||
auto transfer =
|
||||
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
|
||||
typename DstTensorType::TensorElementType,
|
||||
remove_cvref_t<decltype(in_grid_desc)>,
|
||||
remove_cvref_t<decltype(out_grid_desc)>,
|
||||
decltype(thread_slice_lengths),
|
||||
decltype(dim_access_order),
|
||||
decltype(src_vector_tensor_lengths),
|
||||
decltype(dim_access_order)>{
|
||||
src_tensor.GetMultiIdxOffsets()};
|
||||
auto transfer = ThreadwiseTensorSliceTransfer_v2<
|
||||
std::remove_const_t<typename SrcTensorType::TensorElementType>,
|
||||
std::remove_const_t<typename DstTensorType::TensorElementType>,
|
||||
remove_cvref_t<decltype(in_grid_desc)>,
|
||||
remove_cvref_t<decltype(out_grid_desc)>,
|
||||
decltype(thread_slice_lengths),
|
||||
decltype(dim_access_order),
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
I1,
|
||||
false,
|
||||
false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
|
||||
|
||||
transfer.Run(in_grid_desc,
|
||||
src_dst_slice_origin,
|
||||
src_tensor.GetBuffer(),
|
||||
out_grid_desc,
|
||||
src_dst_slice_origin,
|
||||
dst_slice_origin_idxs,
|
||||
dst_tensor.GetBuffer());
|
||||
}
|
||||
else
|
||||
@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
|
||||
index_t ScalarPerVector,
|
||||
typename SrcTensorType,
|
||||
typename DstTensorType,
|
||||
typename ThreadLayoutTuple>
|
||||
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
|
||||
DstTensorType& dst_tensor,
|
||||
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
|
||||
typename ThreadShape,
|
||||
typename ThreadUnrolledDesc>
|
||||
__device__ void
|
||||
blockwise_copy(const SrcTensorType& src_tensor,
|
||||
DstTensorType& dst_tensor,
|
||||
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
|
||||
{
|
||||
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
|
||||
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
|
||||
@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
|
||||
|
||||
constexpr auto tile_lengths_seq =
|
||||
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
|
||||
constexpr auto thread_layout_seq = generate_sequence_v2(
|
||||
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
|
||||
constexpr auto thread_layout_seq =
|
||||
generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
|
||||
constexpr auto dim_access_order = generate_sequence_v2(
|
||||
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
|
||||
using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
|
||||
|
||||
// Perform copy between DynamicBuffers
|
||||
auto transfer = ThreadGroupTensorSliceTransfer_v7<
|
||||
|
||||
Reference in New Issue
Block a user