mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
removing dependency on old tensor descriptor
This commit is contained in:
@@ -6,8 +6,8 @@
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -115,16 +115,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<KBlockWork, BBlockWork>{}),
|
||||
make_tuple(Merge<Sequence<KBlockWork, BBlockWork>>{}),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id = block_work_desc.CalculateLowerIndex(get_block_1d_id());
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global memory
|
||||
@@ -185,11 +182,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_k_e_global_desc =
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
|
||||
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
reorder_tensor_descriptor_given_upper2lower(wei_k_e_global_desc, Sequence<1, 0>{});
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -103,13 +103,12 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global mem
|
||||
@@ -157,21 +156,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
// global mem
|
||||
#if 0
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
transform_tensor_descriptor(wei_k_c_y_x_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
|
||||
make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#else // hack
|
||||
constexpr auto wei_e_k_global_desc_old =
|
||||
WeiGlobalDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
constexpr auto wei_e_k_global_desc = make_native_tensor_descriptor(
|
||||
wei_e_k_global_desc_old.GetLengths(), wei_e_k_global_desc_old.GetStrides());
|
||||
#endif
|
||||
// weight tensor
|
||||
// global mem
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
|
||||
// LDS
|
||||
// be careful of LDS alignment
|
||||
@@ -267,9 +255,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
|
||||
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
|
||||
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
@@ -292,8 +280,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
@@ -301,26 +289,26 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_register_buffer);
|
||||
blockwise_in_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
|
||||
p_in_register_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
|
||||
p_wei_register_buffer, p_wei_block_next);
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
@@ -329,19 +317,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_register_buffer);
|
||||
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
|
||||
p_in_register_buffer, p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
|
||||
p_wei_register_buffer, p_wei_block_double + wei_block_space);
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
@@ -402,10 +390,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
b_thread_data_on_global / B1,
|
||||
b_thread_data_on_global % B1})
|
||||
#if 1
|
||||
.template Run_generic<Float, address_space_t::generic, address_space_t::global>
|
||||
#elif 1
|
||||
.template Run_generic<Float,
|
||||
Float,
|
||||
address_space_t::generic,
|
||||
address_space_t::global>
|
||||
#else // tweaking
|
||||
.template Run_optimized_dst_address_calculation<Float,
|
||||
address_space_t::vgpr,
|
||||
Float,
|
||||
address_space_t::generic,
|
||||
address_space_t::global>
|
||||
#endif
|
||||
(p_out_thread, p_out_global);
|
||||
|
||||
@@ -132,7 +132,7 @@ struct Merge
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperLengths()
|
||||
{
|
||||
return Sequence<accumulate_on_sequence(
|
||||
return Sequence<reduce_on_sequence(
|
||||
LowerLengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
|
||||
}
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
|
||||
}
|
||||
|
||||
#if 0
|
||||
// not implemented
|
||||
template <typename LowerTensorDescriptor,
|
||||
typename PadDimensionIds,
|
||||
typename LeftPads,
|
||||
@@ -171,6 +172,42 @@ __host__ __device__ constexpr auto
|
||||
}
|
||||
#endif
|
||||
|
||||
// a cluster map 1d index to N-d index
|
||||
template <typename Lengths, typename ArrangeOrder>
|
||||
struct ClusterDescriptor
|
||||
{
|
||||
static constexpr index_t nDim = Lengths::Size();
|
||||
|
||||
static constexpr auto mDesc = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(Lengths{}),
|
||||
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
|
||||
make_tuple(ArrangeOrder{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
__host__ __device__ constexpr ClusterDescriptor()
|
||||
{
|
||||
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
|
||||
"wrong! size not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
|
||||
|
||||
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
|
||||
{
|
||||
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lengths,
|
||||
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
|
||||
__host__ __device__ constexpr auto make_cluster_descriptor(
|
||||
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
|
||||
{
|
||||
return ClusterDescriptor<Lengths, ArrangeOrder>{};
|
||||
}
|
||||
|
||||
template <typename... NativeDimensions>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "tensor_coordinate_v2.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
|
||||
@@ -16,6 +13,8 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if 0
|
||||
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst.
|
||||
// This functions assume each thread is reading and writing a normal (not merged) tensor,
|
||||
@@ -677,6 +676,8 @@ struct BlockwiseGenericTensorSliceCopy_v3
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename BlockSrcDesc,
|
||||
typename BlockDstDesc,
|
||||
@@ -710,42 +711,17 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
#if 1
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
#else
|
||||
constexpr auto thread_cluster_lengths_in_arrange_order =
|
||||
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{});
|
||||
|
||||
constexpr auto thread_cluster_desc = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(thread_cluster_lengths_in_arrange_order),
|
||||
make_tuple(Merge<decltype(thread_cluster_lengths_in_arrange_order)>{}),
|
||||
make_tuple(arithmetic)
|
||||
|
||||
::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
constexpr auto thread_cluster_id = transform_tensor_descriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<KBlockWork, BBlockWork>{}),
|
||||
make_tuple(Merge<Sequence<KBlockWork, BBlockWork>>{}),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto block_work_multi_id = block_work_desc.CalculateLowerIndex(get_block_1d_id());
|
||||
#endif
|
||||
// map threads to cluster
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
|
||||
"wrong! BlockSize not consistent with ThreadClusterLengths");
|
||||
|
||||
const auto thread_cluster_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_id =
|
||||
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_id_begin = data_cluster_id * ThreadSliceLengths{};
|
||||
const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
|
||||
|
||||
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
|
||||
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
#include "tensor_view.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "tensor_coordinate_v2.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
@@ -23,6 +20,8 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if 0
|
||||
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the dimensions of vector access to be different on src and dst.
|
||||
// It also allows the vector size to be different on src and dst.
|
||||
@@ -1121,6 +1120,8 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
|
||||
DstSlice mDstSlice;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// This version use multi-index transformation
|
||||
// This threadwise copy allow vector access of src and dst.
|
||||
// It allows the vector size to be different on src and dst.
|
||||
|
||||
@@ -473,6 +473,13 @@ struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
|
||||
using sorted_ids = Sequence<Id>;
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
|
||||
{
|
||||
using sorted_values = Sequence<>;
|
||||
using sorted_ids = Sequence<>;
|
||||
};
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user