mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
added GetLinearDimensionMask
This commit is contained in:
@@ -173,8 +173,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
Sequence<0, 1, 0, 1>,
|
||||
Sequence<1, 1, 1, 1>,
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
@@ -216,8 +214,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
Sequence<1, 1>,
|
||||
Sequence<1, 1>,
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
@@ -427,8 +423,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_desc),
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<1, 1, 1, 0, 1>,
|
||||
decltype(
|
||||
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
|
||||
@@ -145,8 +145,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
Sequence<0, 0>,
|
||||
Sequence<1, 1>,
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
@@ -186,8 +184,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
Sequence<1, 1>,
|
||||
Sequence<1, 1>,
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
@@ -390,8 +386,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_k0_k1_b0_b1_thread_desc),
|
||||
decltype(out_k0_k1_b0_b1_global_desc),
|
||||
Sequence<1, 1, 1, 1>,
|
||||
Sequence<1, 1, 0, 0>,
|
||||
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 4, 1>::type,
|
||||
3,
|
||||
|
||||
@@ -325,14 +325,14 @@ struct TensorCoordinate
|
||||
private:
|
||||
// Overload taken when the argument is a ConstantTensorDescriptor<Ts...>.
// The parameter is unnamed, so only its type participates; the function returns a
// value-initialized NormalTensorCoordinate parameterized on that descriptor type.
// NOTE(review): the name suggests this exists purely for type deduction of the matching
// coordinate type -- confirm against the enclosing TensorCoordinate struct.
// NOTE(review): the signature line appears twice below; this looks like the before/after
// pair of a diff hunk rather than two declarations -- confirm.
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate<ConstantTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
// Overload taken when the argument is a ConstantMergedTensorDescriptor<Ts...>.
// The parameter is unnamed, so only its type participates; the function returns a
// value-initialized MergedTensorCoordinate parameterized on that descriptor type.
// NOTE(review): the name suggests this exists purely for type deduction of the matching
// coordinate type -- confirm against the enclosing TensorCoordinate struct.
// NOTE(review): the signature line appears twice below; this looks like the before/after
// pair of a diff hunk rather than two declarations -- confirm.
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ struct TensorCoordinate_v2
|
||||
private:
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
@@ -200,7 +200,7 @@ struct TensorCoordinate_v2
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
|
||||
@@ -353,7 +353,6 @@ struct TransformedTensorDescriptor
|
||||
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
|
||||
}
|
||||
|
||||
#if 1
|
||||
struct lambda_sequence_logical_and
|
||||
{
|
||||
template <typename... Seqs>
|
||||
@@ -378,7 +377,7 @@ struct TransformedTensorDescriptor
|
||||
// check only one transform at a time
|
||||
template <typename Transform, typename LowDimensionId, typename UpDimensionId>
|
||||
__host__ __device__ constexpr auto
|
||||
operator()(const Transform& tran, LowDimensionId, UpDimensionId) const
|
||||
operator()(Transform, LowDimensionId, UpDimensionId) const
|
||||
{
|
||||
// judge if transformation is linear
|
||||
constexpr bool is_linear_transform = Transform::IsLinearTransform();
|
||||
@@ -392,23 +391,42 @@ struct TransformedTensorDescriptor
|
||||
// create linear mask for upper dimensions
|
||||
constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear;
|
||||
|
||||
constexpr auto mask_of_up_linear_dims = modifiy_sequence_by_ids(
|
||||
typename uniform_sequence_gen<nDimUp, 0>::type{},
|
||||
typename uniform_sequence_gen<UpDimensionId::Size(), 1>::type{},
|
||||
constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids(
|
||||
typename uniform_sequence_gen<nDimUp, 1>::type{},
|
||||
typename uniform_sequence_gen<UpDimensionId::Size(), are_up_dim_linear>::type{},
|
||||
UpDimensionId{});
|
||||
|
||||
return mask_of_up_linear_dims;
|
||||
}
|
||||
};
|
||||
|
||||
// Element-wise transform over three tuple-like containers.
// For every index I in Is..., calls f(x.At(Number<I>{}), y.At(Number<I>{}), z.At(Number<I>{}))
// and packs the results, in index order, into a tuple via make_tuple.
//   f              - callable invoked once per index with the three selected elements
//   x, y, z        - containers exposing At(Number<I>{}); taken by value here
//   Sequence<Is...> - supplies the compile-time indices to visit
// TODO: this is a hack; transform_tuples() doesn't compile at the call site (the compiler
// complains about constexpr), so this local stand-in is used instead.
|
||||
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
||||
__host__ __device__ static constexpr auto
|
||||
dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence<Is...>)
|
||||
{
|
||||
// pack expansion: one f(...) call per index in Is...
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLinearDimensionMask()
|
||||
{
|
||||
#if 0
|
||||
// create tuple of linear dimension masks, for all transformations
|
||||
constexpr auto tuple_of_linear_dimension_mask =
|
||||
transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
|
||||
Transforms{},
|
||||
LowDimensionIds{},
|
||||
UpDimensionIds{});
|
||||
#else
|
||||
// create tuple of linear dimension masks, for all transformations
|
||||
// TODO: this is a hack, transform_tuples() doesn't compile, complain about constexpr
|
||||
constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
|
||||
lambda_get_linear_dimension_mask_of_single_tranform{},
|
||||
Transforms{},
|
||||
LowDimensionIds{},
|
||||
UpDimensionIds{},
|
||||
typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{});
|
||||
#endif
|
||||
|
||||
// reduce tuple of masks into one mask
|
||||
constexpr auto linear_dimension_mask =
|
||||
@@ -444,6 +462,7 @@ struct TransformedTensorDescriptor
|
||||
typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask);
|
||||
}
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
{
|
||||
// not implemented
|
||||
|
||||
@@ -680,8 +680,6 @@ struct BlockwiseGenericTensorSliceCopy_v3
|
||||
template <index_t BlockSize,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcLinearDimensionMask,
|
||||
typename DstLinearDimensionMask,
|
||||
typename SliceLengths,
|
||||
typename SubLengths,
|
||||
typename ThreadClusterLengths,
|
||||
@@ -794,27 +792,21 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_native_tensor_descriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SrcLinearDimensionMask,
|
||||
typename uniform_sequence_gen<nDim, 1>::type,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore =
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
typename uniform_sequence_gen<nDim, 1>::type,
|
||||
DstLinearDimensionMask,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
|
||||
@@ -1130,8 +1130,6 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
|
||||
// the other is device memory or LDS
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcLinearDimensionMask,
|
||||
typename DstLinearDimensionMask,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorAccessDim,
|
||||
@@ -1315,11 +1313,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
// TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
|
||||
// is implemented
|
||||
constexpr auto src_linear_dim_mask = SrcLinearDimensionMask{};
|
||||
constexpr auto src_nonlinear_dim_mask =
|
||||
SrcLinearDimensionMask::Transform(logical_not<index_t>{});
|
||||
// separate linear dimensions from non-linear dimensions
|
||||
constexpr auto src_linear_dim_mask = SrcDesc::GetLinearDimensionMask();
|
||||
constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();
|
||||
|
||||
static_assert(
|
||||
src_linear_dim_mask.At(VectorAccessDim) || long_vector_size == SrcDataPerAccess,
|
||||
@@ -1459,11 +1455,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
// TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
|
||||
// is implemented
|
||||
constexpr auto dst_linear_dim_mask = DstLinearDimensionMask{};
|
||||
constexpr auto dst_nonlinear_dim_mask =
|
||||
DstLinearDimensionMask::Transform(logical_not<index_t>{});
|
||||
// separate linear dimensions from non-linear dimensions
|
||||
constexpr auto dst_linear_dim_mask = DstDesc::GetLinearDimensionMask();
|
||||
constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();
|
||||
|
||||
static_assert(
|
||||
dst_linear_dim_mask.At(VectorAccessDim) || long_vector_size == DstDataPerAccess,
|
||||
|
||||
@@ -125,6 +125,13 @@ transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
|
||||
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
// Three-container overload of the transform helper.
// For every index I in Is..., calls f(x.At(Number<I>{}), y.At(Number<I>{}), z.At(Number<I>{}))
// and returns the results, in index order, as a tuple built with make_tuple.
//   f              - callable invoked once per index with the three selected elements
//   x, y, z        - containers exposing At(Number<I>{}); taken by const reference
//   Sequence<Is...> - supplies the compile-time indices to visit
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
|
||||
{
|
||||
// pack expansion: one f(...) call per index in Is...
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
@@ -141,5 +148,12 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
// Three-container entry point: forwards f, x, y, z to detail::transform_tuples_impl together
// with an index sequence generated by arithmetic_sequence_gen<0, X::Size(), 1>, and returns
// whatever that helper returns.
// The index sequence is sized from X alone; y and z are accessed with the same indices.
// NOTE(review): presumably y and z must hold at least X::Size() elements -- nothing here
// checks it; confirm against callers.
template <typename F, typename X, typename Y, typename Z>
|
||||
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user