mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper * Add basic gemm example * Update docs * Add tutorial for gemm using ck wrapper * Add perf note * edits * Fix cmake * Fixes --------- Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace wrapper {
|
||||
@@ -29,6 +30,7 @@ template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
namespace {
|
||||
namespace detail {
|
||||
/**
|
||||
* \brief Generate packed (column-major) strides if not passed
|
||||
*
|
||||
@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
|
||||
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/// @endcond
|
||||
@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
|
||||
template <typename Shape, typename Strides>
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
|
||||
{
|
||||
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
|
||||
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape,
|
||||
detail::MakeUnrolledDescriptor(shape, strides));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
|
||||
template <typename Shape>
|
||||
__host__ __device__ constexpr auto make_layout(const Shape& shape)
|
||||
{
|
||||
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
|
||||
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
|
||||
return Layout<Shape, UnrolledDescriptorType>(shape,
|
||||
detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
|
||||
}
|
||||
|
||||
// Layout helpers
|
||||
// get
|
||||
|
||||
/**
|
||||
* \private
|
||||
* \brief Get dim.
|
||||
@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
|
||||
* \param layout Layout to create sub layout.
|
||||
 * \return Requested sub layout.
|
||||
*/
|
||||
template <index_t idx, typename Shape, typename FlattenDesc>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
|
||||
template <index_t idx, typename Shape, typename UnrolledDesc>
|
||||
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
|
||||
{
|
||||
const auto& shape = layout.GetShape();
|
||||
const auto new_shape = get<idx>(shape);
|
||||
@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
|
||||
return layout.GetShape();
|
||||
}
|
||||
|
||||
// pad
|
||||
/**
|
||||
* \brief Pad layout shapes to be adjusted to tile lengths.
|
||||
*
|
||||
*
|
||||
* \param layout Layout to pad.
|
||||
* \param tile_lengths Tile lengths to align layout shape.
|
||||
* \return Padded layout.
|
||||
*/
|
||||
template <typename Shape, typename UnrolledDesc, typename TileLengths>
|
||||
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
|
||||
const TileLengths& tile_lengths)
|
||||
{
|
||||
auto& unrolled_desc = layout.GetUnrolledDescriptor();
|
||||
// Generate sequence with ones to mark that all dims will be padded
|
||||
constexpr auto do_pads_seq =
|
||||
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
|
||||
// Create descriptor with padding
|
||||
auto padded_desc =
|
||||
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
|
||||
// Generate padded shape
|
||||
const auto padded_shape = generate_tuple(
|
||||
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
|
||||
// Create layout
|
||||
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
|
||||
}
|
||||
|
||||
// unmerge
|
||||
/**
|
||||
* \brief Unmerge selected dim in layout.
|
||||
*
|
||||
* \tparam Idx Index to dimension being unmerged.
|
||||
* \param layout Layout to pad.
|
||||
* \param new_lengths Dimensions into which the indicated dimension will be divided.
|
||||
* \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
|
||||
* \return Unmerged layout.
|
||||
*/
|
||||
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
|
||||
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
|
||||
const NewLengths& new_lengths,
|
||||
[[maybe_unused]] const NewIdxs& new_indexes)
|
||||
{
|
||||
const auto& layout_shape = shape(layout);
|
||||
auto& unrolled_desc = layout.GetUnrolledDescriptor();
|
||||
constexpr auto dims = Shape::Size();
|
||||
// Generate transforms
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == Idx)
|
||||
{
|
||||
return make_unmerge_transform(new_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(layout_shape.At(i));
|
||||
}
|
||||
},
|
||||
Number<dims>{});
|
||||
|
||||
constexpr auto lower_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
|
||||
constexpr auto upper_dims = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
|
||||
{
|
||||
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
|
||||
return to_sequence(idxs_tuple);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
|
||||
return Sequence<index>{};
|
||||
}
|
||||
},
|
||||
Number<dims>{});
|
||||
|
||||
const auto unmerged_desc =
|
||||
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
|
||||
const auto unmerged_shape =
|
||||
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
|
||||
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
|
||||
|
||||
// Create layout
|
||||
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user