mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add blockwise gemm to ck wrapper (#1139)
* Add blockwise gemm to ck wrapper * Add blockwise gemm traits * Disable test_gemm for non xdl devices * Fixes * Add c layout descriptions
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "tensor_utils.hpp"
|
||||
#include "layout_utils.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
|
||||
@@ -14,6 +15,8 @@ namespace wrapper {
|
||||
|
||||
namespace {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/**
|
||||
* \brief Calculate shape for partition based on number of threads per each dim and
|
||||
* previous shape
|
||||
@@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
|
||||
const auto slice_len =
|
||||
ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
|
||||
return slice_len;
|
||||
},
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Apply projection.
|
||||
*
|
||||
* \param base_tuple Tuple to apply projection.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Multi index after projection.
|
||||
*/
|
||||
template <typename MultiIndex, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
|
||||
[[maybe_unused]] const ProjectionTuple& projection)
|
||||
{
|
||||
if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
|
||||
{
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto base_tuple_after_projection = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto i_num = Number<i.value>{};
|
||||
static_assert(
|
||||
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
|
||||
is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
|
||||
if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
|
||||
{
|
||||
// When slice (to remove), then insert empty tuple (will be removed in next
|
||||
// step).
|
||||
return Tuple<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return base_tuple.At(i_num);
|
||||
}
|
||||
},
|
||||
Number<MultiIndex::Size()>{});
|
||||
// Remove empty tuples
|
||||
return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate shape with dims from projection.
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Shape with dims from projection
|
||||
*/
|
||||
template <typename... Ts, typename... Ps>
|
||||
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ps...>& projection)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
|
||||
{
|
||||
return size<i>(projection).to_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// number of shape element in actual fragment of shape and projection (method to
|
||||
// calculate shape idx)
|
||||
constexpr index_t shape_i =
|
||||
detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
|
||||
TupleSlice<0, i>(Tuple<Ps...>{}))
|
||||
.Size();
|
||||
return size<shape_i>(shape);
|
||||
}
|
||||
},
|
||||
Number<Tuple<Ps...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate total number of blocks.
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param tile_shape Tile shape.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Tuple with blocks number.
|
||||
*/
|
||||
template <typename... Ts, typename... Ls>
|
||||
template <typename... Ts, typename... Ls, typename... Ps>
|
||||
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ls...>& tile_shape)
|
||||
const Tuple<Ls...>& tile_shape,
|
||||
const Tuple<Ps...>& projection)
|
||||
{
|
||||
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
|
||||
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
|
||||
size<i>(tile_shape));
|
||||
},
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,8 +155,75 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
return thread_idxs * partition_lengths_seq + old_offset_idxs;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate default projection.
|
||||
*
|
||||
* \param tile_shape Tile shape.
|
||||
* \return Default projection (filled with Number<1>{}).
|
||||
*/
|
||||
template <typename TileShape>
|
||||
__host__ __device__ constexpr auto
|
||||
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
|
||||
{
|
||||
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread (At now only packed partition
|
||||
* is supported).
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param thread_lengths Layout of threads (could not be nested).
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
|
||||
// Calculate new partition shape
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
// Calculate projected thread lengths
|
||||
constexpr auto projected_thread_lengths =
|
||||
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
|
||||
constexpr auto partition_shape =
|
||||
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
|
||||
// Create Thread Cluster Descriptor
|
||||
constexpr auto partition_shape_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
|
||||
Number<decltype(partition_shape)::Size()>{});
|
||||
constexpr auto thread_lengths_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
|
||||
Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
|
||||
// Calculate thread idxs and offsets
|
||||
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
// Apply projection on thread idxs to remove not needed idxs
|
||||
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
|
||||
partition_shape, unrolled_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
|
||||
// Apply offsets
|
||||
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local partition for thread (At now only packed partition
|
||||
* is supported).
|
||||
@@ -81,37 +234,12 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id)
|
||||
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
|
||||
const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id)
|
||||
{
|
||||
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
|
||||
// Calculate new partition shape
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
constexpr auto partition_shape =
|
||||
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
|
||||
// Create Thread Cluster Descriptor
|
||||
constexpr auto partition_lengths_seq = generate_sequence_v2(
|
||||
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_lengths_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
|
||||
Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
|
||||
// Calculate thread idxs and offsets
|
||||
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
const auto offset_multi_idxs =
|
||||
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
|
||||
partition_shape, flatten_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
|
||||
// Apply offsets
|
||||
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
return partition_tensor;
|
||||
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
|
||||
return make_local_partition(tensor, thread_lengths, thread_id, projection);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor,
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
|
||||
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
const BlockShapeTuple& tile_shape,
|
||||
const index_t block_id,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(BlockShapeTuple{}));
|
||||
|
||||
constexpr bool is_default_projection =
|
||||
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
|
||||
|
||||
if constexpr(BlockShapeTuple::Size() == I2)
|
||||
// TODO: Enable block_2_tile_map partitioning for non-default projection.
|
||||
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
|
||||
{
|
||||
// Optimized version for 2d tile shape [MxK]
|
||||
const auto block_2_tile_map =
|
||||
@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
|
||||
{
|
||||
// Calculate offsets
|
||||
// Sequence with data to process per block
|
||||
constexpr auto tile_shape_seq =
|
||||
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
|
||||
Number<BlockShapeTuple::Size()>{});
|
||||
constexpr auto projected_tile_shape =
|
||||
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
|
||||
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
|
||||
constexpr auto projected_tile_shape_seq =
|
||||
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
|
||||
Number<ProjectedTileShapeTuple::Size()>{});
|
||||
// Tuple with number of blocks
|
||||
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
|
||||
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
|
||||
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
|
||||
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
|
||||
const auto block_idxs =
|
||||
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
|
||||
const auto offset_multi_idxs =
|
||||
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
const auto tile_layout =
|
||||
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
|
||||
aligned_desc);
|
||||
Layout<remove_reference_t<ProjectedTileShapeTuple>, decltype(aligned_desc)>(
|
||||
projected_tile_shape, aligned_desc);
|
||||
auto tile_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
|
||||
// Apply offsets
|
||||
@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Create local tile for thread block. (At now only packed tile
|
||||
* is supported).
|
||||
*
|
||||
* \note Currently to get the best performance please use 2d shape.
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
|
||||
{
|
||||
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
|
||||
return make_local_tile(tensor, tile_shape, block_id, projection);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Pad tensor shapes to be adjusted to tile lengths.
|
||||
*
|
||||
*
|
||||
* \param tensor Tensor to pad.
|
||||
* \param tile_lengths Tile lengths to align tensor shape.
|
||||
* \return Padded tensor.
|
||||
*/
|
||||
template <typename TensorType, typename TileLengths>
|
||||
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
|
||||
{
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
// Generate sequence with ones to mark that all dims will be padded
|
||||
constexpr auto do_pads_seq =
|
||||
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
|
||||
// Create descriptor with padding
|
||||
auto padded_desc =
|
||||
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
|
||||
// Generate padded shape
|
||||
const auto padded_shape = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& dim = size<i>(tensor_shape);
|
||||
const auto& tile_length = size<i>(tile_lengths);
|
||||
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
|
||||
},
|
||||
Number<TileLengths::Size()>{});
|
||||
// Create layout and tensor
|
||||
const auto padded_layout =
|
||||
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
|
||||
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
@@ -19,9 +20,9 @@ namespace wrapper {
|
||||
* \brief Memory type, allowed members:
|
||||
* - Generic,
|
||||
* - Global,
|
||||
* - LDS,
|
||||
* - SGPR,
|
||||
* - VGPR,
|
||||
* - Lds,
|
||||
* - Sgpr,
|
||||
* - Vgpr,
|
||||
*/
|
||||
using MemoryTypeEnum = AddressSpaceEnum;
|
||||
|
||||
@@ -52,12 +53,8 @@ struct Slice
|
||||
__host__ __device__ constexpr auto range(const T& dim) const
|
||||
{
|
||||
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
|
||||
is_same_v<T, index_t>)
|
||||
is_same_v<std::remove_const_t<T>, index_t>)
|
||||
{
|
||||
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
|
||||
{
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
if(to_ < 0)
|
||||
{
|
||||
return dim - from_ + to_ + 1;
|
||||
@@ -70,9 +67,10 @@ struct Slice
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
|
||||
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
|
||||
(ToType{} < 0 || ToType{} > FromType{}),
|
||||
"Invalid range");
|
||||
if constexpr(to_ < 0)
|
||||
if constexpr(ToType{} < 0)
|
||||
{
|
||||
return dim - from_ + to_ + Number<1>{};
|
||||
}
|
||||
@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
|
||||
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Clear tensor. (Only for Vpgr/Sgpr)
|
||||
*
|
||||
* \param tensor Tensor to be cleared.
|
||||
*/
|
||||
template <MemoryTypeEnum BufferAddressSpace,
|
||||
typename ElementType,
|
||||
typename Shape,
|
||||
typename UnrolledDescriptorType>
|
||||
__host__ __device__ void
|
||||
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
|
||||
{
|
||||
static_assert(
|
||||
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
|
||||
return tensor.GetBuffer().Clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get Tensor Layout.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user