mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper * Add basic gemm example * Update docs * Add tutorial for gemm using ck wrapper * Add perf note * edits * Fix cmake * Fixes --------- Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
This commit is contained in:
@@ -6,7 +6,6 @@
|
||||
#include "tensor_utils.hpp"
|
||||
#include "layout_utils.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
|
||||
@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
|
||||
* \brief Apply projection.
|
||||
*
|
||||
* \param base_tuple Tuple to apply projection.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Multi index after projection.
|
||||
*/
|
||||
template <typename MultiIndex, typename ProjectionTuple>
|
||||
@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
|
||||
}
|
||||
else
|
||||
{
|
||||
return base_tuple.At(i_num);
|
||||
return make_tuple(base_tuple.At(i_num));
|
||||
}
|
||||
},
|
||||
Number<MultiIndex::Size()>{});
|
||||
@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
|
||||
* \brief Calculate shape with dims from projection.
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Shape with dims from projection
|
||||
*/
|
||||
template <typename... Ts, typename... Ps>
|
||||
@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
|
||||
*
|
||||
* \param shape Base tensor shape.
|
||||
* \param tile_shape Tile shape.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Tuple with blocks number.
|
||||
*/
|
||||
template <typename... Ts, typename... Ls, typename... Ps>
|
||||
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
|
||||
const Tuple<Ls...>& tile_shape,
|
||||
const Tuple<Ps...>& projection)
|
||||
const Tuple<Ls...>& tile_shape)
|
||||
{
|
||||
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
|
||||
size<i>(tile_shape));
|
||||
},
|
||||
[&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
|
||||
Number<Tuple<Ls...>::Size()>{});
|
||||
}
|
||||
|
||||
@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
|
||||
return thread_idxs * partition_lengths_seq + old_offset_idxs;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Select dims to partition (skip if slice).
|
||||
*
|
||||
* \param block_idxs Input block indexes.
|
||||
* \return Partitioned dims.
|
||||
*/
|
||||
template <typename BlockIdxs>
|
||||
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
|
||||
{
|
||||
const auto dims_to_partition = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
|
||||
{
|
||||
return Number<i>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Tuple<>{};
|
||||
}
|
||||
},
|
||||
Number<BlockIdxs::Size()>{});
|
||||
// Remove empty tuples
|
||||
return UnrollNestedTuple<0, 1>(dims_to_partition);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Replace slices with zeros (Slice dims are not partitioned).
|
||||
*
|
||||
* \param block_idxs Input block indexes.
|
||||
* \return Parsed dims.
|
||||
*/
|
||||
template <typename BlockIdxs>
|
||||
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
|
||||
{
|
||||
return block_idxs.At(i);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<BlockIdxs::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate default projection.
|
||||
*
|
||||
@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
|
||||
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Calculate thread multi index from 1d thread index.
|
||||
*
|
||||
* \param thread_layout Layout of threads (could not be nested).
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \return Multi index.
|
||||
*/
|
||||
template <typename ThreadShape, typename ThreadUnrolledDesc>
|
||||
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
|
||||
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
|
||||
const index_t thread_id)
|
||||
{
|
||||
static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
|
||||
"Thread layout should not be transformed.");
|
||||
constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
|
||||
constexpr auto shape = ThreadShape{};
|
||||
constexpr auto strides = embed_transform.coefficients_;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr auto num_i = Number<i>{};
|
||||
return (thread_id / strides.At(num_i)) % shape.At(num_i);
|
||||
},
|
||||
Number<ThreadShape::Size()>{});
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace
|
||||
|
||||
@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
|
||||
* is supported).
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param thread_lengths Layout of threads (could not be nested).
|
||||
* \param thread_layout Layout of threads (could not be transformed).
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
|
||||
template <typename TensorType,
|
||||
typename ThreadShape,
|
||||
typename ThreadUnrolledDesc,
|
||||
typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
|
||||
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
|
||||
const index_t thread_id,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
|
||||
static_assert(!IsNestedTuple(ThreadShape{}));
|
||||
// Calculate new partition shape
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
// Calculate projected thread lengths
|
||||
constexpr auto projected_thread_lengths =
|
||||
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
|
||||
detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
|
||||
constexpr auto partition_shape =
|
||||
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
|
||||
// Create Thread Cluster Descriptor
|
||||
constexpr auto partition_shape_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
|
||||
Number<decltype(partition_shape)::Size()>{});
|
||||
constexpr auto thread_lengths_seq =
|
||||
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
|
||||
Number<ThreadLengthsTuple::Size()>{});
|
||||
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
|
||||
// Calculate thread idxs and offsets
|
||||
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
|
||||
// Apply projection on thread idxs to remove not needed idxs
|
||||
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
// Slice descriptor
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_slice_transform(partition_shape.At(i),
|
||||
offset_multi_idxs.At(i),
|
||||
partition_shape.At(i) + offset_multi_idxs.At(i));
|
||||
},
|
||||
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
|
||||
const auto lower_upper_dims =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
|
||||
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
|
||||
auto sliced_desc =
|
||||
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
|
||||
// Create layout
|
||||
const auto partition_layout =
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
|
||||
partition_shape, unrolled_desc);
|
||||
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
|
||||
partition_shape, sliced_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
|
||||
// Apply offsets
|
||||
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
return partition_tensor;
|
||||
}
|
||||
|
||||
@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
|
||||
* \param thread_id Thread index represented as integer.
|
||||
* \return Partition tensor.
|
||||
*/
|
||||
template <typename TensorType, typename ThreadLengthsTuple>
|
||||
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
|
||||
const ThreadLengthsTuple& thread_lengths,
|
||||
const index_t thread_id)
|
||||
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_partition(TensorType& tensor,
|
||||
const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
|
||||
const index_t thread_id)
|
||||
{
|
||||
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
|
||||
const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
|
||||
return make_local_partition(tensor, thread_lengths, thread_id, projection);
|
||||
}
|
||||
|
||||
@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
* \param projection Projection to remove selected dim from partitioning.
|
||||
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
|
||||
* \param block_idxs Tuple of block indexes represented as integer. If slice,
|
||||
* then get whole dim.
|
||||
* \param projection Projection is used to remove selected dim from
|
||||
* partitioning. Use `slice(X)` to remove dimension, where X is dim
|
||||
* size. Use `Number<1>{}` to keep it.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
|
||||
template <typename TensorType,
|
||||
typename BlockShapeTuple,
|
||||
typename BlockIdxs,
|
||||
typename ProjectionTuple>
|
||||
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
const BlockShapeTuple& tile_shape,
|
||||
const index_t block_id,
|
||||
const BlockIdxs& block_idxs,
|
||||
const ProjectionTuple& projection)
|
||||
{
|
||||
static_assert(!IsNestedTuple(BlockShapeTuple{}));
|
||||
|
||||
constexpr bool is_default_projection =
|
||||
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
|
||||
static_assert(!IsNestedTuple(BlockIdxs{}));
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
|
||||
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
|
||||
|
||||
// TODO: Enable block_2_tile_map partitioning for non-default projection.
|
||||
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
|
||||
constexpr auto projected_tile_shape =
|
||||
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
|
||||
// Number of dims which are partitioned
|
||||
constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
|
||||
const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
|
||||
if constexpr(decltype(dims_to_partition)::Size() == I2)
|
||||
{
|
||||
// Optimized version for 2d tile shape [MxK]
|
||||
const auto shape_with_projection_dims =
|
||||
detail::CalculateShapeWithProjection(shape(tensor), projection);
|
||||
// Set Value for M, N partition
|
||||
const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
|
||||
const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
|
||||
constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
|
||||
constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
|
||||
auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
|
||||
// Get 1D block id
|
||||
const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
|
||||
const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
|
||||
const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
|
||||
// Optimized version for 2d tile shape [MxN]
|
||||
const auto block_2_tile_map =
|
||||
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
|
||||
BlockShapeTuple{}.At(I1),
|
||||
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
|
||||
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
|
||||
NPerBlock,
|
||||
remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
|
||||
const auto block_work_idx =
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
|
||||
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
|
||||
const index_t k_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
|
||||
const auto offset_multi_idxs =
|
||||
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
// Apply 0 for non partitioned dims
|
||||
const auto offset_multi_idxs = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == dims_to_partition.At(I0))
|
||||
{
|
||||
return m_block_data_idx_on_grid;
|
||||
}
|
||||
else if constexpr(i == dims_to_partition.At(I1))
|
||||
{
|
||||
return n_block_data_idx_on_grid;
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<0>{};
|
||||
}
|
||||
},
|
||||
Number<BlockShapeTuple::Size()>{});
|
||||
const auto projected_offset_multi_idxs =
|
||||
detail::ApplyProjection(offset_multi_idxs, projection);
|
||||
// Create new layout and tensor
|
||||
const auto tile_layout =
|
||||
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
|
||||
aligned_desc);
|
||||
Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
|
||||
projected_tile_shape, aligned_desc);
|
||||
auto tile_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
|
||||
// Apply offsets
|
||||
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
|
||||
tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
|
||||
return tile_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Calculate offsets
|
||||
// Sequence with data to process per block
|
||||
constexpr auto projected_tile_shape =
|
||||
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
|
||||
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
|
||||
constexpr auto projected_tile_shape_seq =
|
||||
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
|
||||
Number<ProjectedTileShapeTuple::Size()>{});
|
||||
// Tuple with number of blocks
|
||||
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
|
||||
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
|
||||
const auto block_idxs =
|
||||
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
|
||||
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
const auto projected_block_idxs =
|
||||
to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
|
||||
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
|
||||
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
|
||||
// Create new layout and tensor
|
||||
const auto tile_layout =
|
||||
@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
*
|
||||
* \param tensor Tensor for partition.
|
||||
* \param tile_shape Shapes of requested tile.
|
||||
* \param block_id Block index represented as integer.
|
||||
* \param block_idxs Tuple of block indexes represented as integer. If slice,
|
||||
* then get whole dim.
|
||||
* \return Tile tensor.
|
||||
*/
|
||||
template <typename TensorType, typename BlockShapeTuple>
|
||||
__host__ __device__ constexpr auto
|
||||
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
|
||||
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
|
||||
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
|
||||
const BlockShapeTuple& tile_shape,
|
||||
const BlockIdxs& block_idxs)
|
||||
{
|
||||
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
|
||||
return make_local_tile(tensor, tile_shape, block_id, projection);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Pad tensor shapes to be adjusted to tile lengths.
|
||||
*
|
||||
*
|
||||
* \param tensor Tensor to pad.
|
||||
* \param tile_lengths Tile lengths to align tensor shape.
|
||||
* \return Padded tensor.
|
||||
*/
|
||||
template <typename TensorType, typename TileLengths>
|
||||
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
|
||||
{
|
||||
const auto& tensor_shape = shape(tensor);
|
||||
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
|
||||
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
|
||||
// Generate sequence with ones to mark that all dims will be padded
|
||||
constexpr auto do_pads_seq =
|
||||
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
|
||||
// Create descriptor with padding
|
||||
auto padded_desc =
|
||||
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
|
||||
// Generate padded shape
|
||||
const auto padded_shape = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& dim = size<i>(tensor_shape);
|
||||
const auto& tile_length = size<i>(tile_lengths);
|
||||
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
|
||||
},
|
||||
Number<TileLengths::Size()>{});
|
||||
// Create layout and tensor
|
||||
const auto padded_layout =
|
||||
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
|
||||
auto partition_tensor =
|
||||
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
|
||||
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
|
||||
return partition_tensor;
|
||||
return make_local_tile(tensor, tile_shape, block_idxs, projection);
|
||||
}
|
||||
|
||||
} // namespace wrapper
|
||||
|
||||
Reference in New Issue
Block a user