Add blockwise gemm to ck wrapper (#1139)

* Add blockwise gemm to ck wrapper

* Add blockwise gemm traits

* Disable test_gemm for non xdl devices

* Fixes

* Add c layout descritpions
This commit is contained in:
Bartłomiej Kocot
2024-01-31 21:24:40 +01:00
committed by GitHub
parent 6651a124cc
commit f3b6c23ac5
12 changed files with 1064 additions and 116 deletions

View File

@@ -6,6 +6,7 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
@@ -14,6 +15,8 @@ namespace wrapper {
namespace {
namespace detail {
/**
* \brief Calculate shape for partition based on number of threads per each dim and
* previous shape
@@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
const auto slice_len =
ck::math::integer_divide_ceil(size<num_i>(shape), thread_lengths.At(num_i));
return slice_len;
},
Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \return Multi index after projection.
*/
template <typename MultiIndex, typename ProjectionTuple>
__host__ __device__ constexpr auto
ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
[[maybe_unused]] const ProjectionTuple& projection)
{
if constexpr(is_same_v<ProjectionTuple, Tuple<>>)
{
return Tuple<>{};
}
else
{
auto base_tuple_after_projection = generate_tuple(
[&](auto i) {
const auto i_num = Number<i.value>{};
static_assert(
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
if constexpr(is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
{
// When slice (to remove), then insert empty tuple (will be removed in next
// step).
return Tuple<>{};
}
else
{
return base_tuple.At(i_num);
}
},
Number<MultiIndex::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
}
}
/**
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \return Shape with dims from projection
*/
template <typename... Ts, typename... Ps>
__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts...>& shape,
const Tuple<Ps...>& projection)
{
return generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_slice, tuple_element_t<i, Tuple<Ps...>>>::value)
{
return size<i>(projection).to_;
}
else
{
// number of shape element in actual fragment of shape and projection (method to
// calculate shape idx)
constexpr index_t shape_i =
detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
TupleSlice<0, i>(Tuple<Ps...>{}))
.Size();
return size<shape_i>(shape);
}
},
Number<Tuple<Ps...>::Size()>{});
}
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls>
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape)
const Tuple<Ls...>& tile_shape,
const Tuple<Ps...>& projection)
{
static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
Number<Tuple<Ls...>::Size()>{});
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple(
[&](auto i) {
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
Number<Tuple<Ls...>::Size()>{});
}
/**
@@ -69,8 +155,75 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
/**
* \brief Calculate default projection.
*
* \param tile_shape Tile shape.
* \return Default projection (filled with Number<1>{}).
*/
template <typename TileShape>
__host__ __device__ constexpr auto
GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
{
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
} // namespace detail
} // namespace
/**
* \brief Create local partition for thread (At now only packed partition
* is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested).
* \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths
constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
// Apply projection on thread idxs to remove not needed idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
partition_shape, unrolled_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
/**
* \brief Create local partition for thread (At now only packed partition
* is supported).
@@ -81,37 +234,12 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
constexpr auto partition_shape =
CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
// Create Thread Cluster Descriptor
constexpr auto partition_lengths_seq = generate_sequence_v2(
[&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
partition_shape, flatten_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
/**
@@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor,
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const index_t block_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
if constexpr(BlockShapeTuple::Size() == I2)
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
{
// Optimized version for 2d tile shape [MxK]
const auto block_2_tile_map =
@@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto tile_shape_seq =
generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); },
Number<BlockShapeTuple::Size()>{});
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto offset_multi_idxs =
CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
Layout<remove_reference_t<ProjectedTileShapeTuple>, decltype(aligned_desc)>(
projected_tile_shape, aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
@@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con
}
}
/**
* \brief Create local tile for thread block. (At now only packed tile
* is supported).
*
* \note Currently to get the best performance please use 2d shape.
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
{
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection);
}
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
}
} // namespace wrapper
} // namespace ck

View File

@@ -5,6 +5,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
@@ -19,9 +20,9 @@ namespace wrapper {
* \brief Memory type, allowed members:
* - Generic,
* - Global,
* - LDS,
* - SGPR,
* - VGPR,
* - Lds,
* - Sgpr,
* - Vgpr,
*/
using MemoryTypeEnum = AddressSpaceEnum;
@@ -52,12 +53,8 @@ struct Slice
__host__ __device__ constexpr auto range(const T& dim) const
{
if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
is_same_v<T, index_t>)
is_same_v<std::remove_const_t<T>, index_t>)
{
if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
{
throw std::runtime_error("Invalid range");
}
if(to_ < 0)
{
return dim - from_ + to_ + 1;
@@ -70,9 +67,10 @@ struct Slice
}
else
{
static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_),
static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} &&
(ToType{} < 0 || ToType{} > FromType{}),
"Invalid range");
if constexpr(to_ < 0)
if constexpr(ToType{} < 0)
{
return dim - from_ + to_ + Number<1>{};
}
@@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>&
return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
}
/**
* \brief Clear tensor. (Only for Vpgr/Sgpr)
*
* \param tensor Tensor to be cleared.
*/
template <MemoryTypeEnum BufferAddressSpace,
typename ElementType,
typename Shape,
typename UnrolledDescriptorType>
__host__ __device__ void
clear(Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
{
static_assert(
!Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>::IsDynamicBuffer);
return tensor.GetBuffer().Clear();
}
/**
* \brief Get Tensor Layout.
*