mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
@@ -7,6 +7,10 @@
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -129,10 +133,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 1
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerAccess_N>{};
|
||||
#else
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
@@ -146,11 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
|
||||
{0, 0, 0, 0});
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
|
||||
#else
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
@@ -163,6 +187,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -402,7 +427,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
|
||||
#if 1
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_thread_on_global,
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerAccess_N>{});
|
||||
#else
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
@@ -413,6 +446,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
@@ -460,7 +494,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_10d_thread_desc),
|
||||
#if 1
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_thread_on_global,
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerAccess_N>{});
|
||||
#else
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
@@ -471,6 +513,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -301,14 +301,14 @@ struct TensorCoordinate
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate<ConstantTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user