mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
adding implicit GEMM v4r2
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
@@ -182,6 +181,12 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
|
||||
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
#if 1
|
||||
{
|
||||
printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
|
||||
@@ -53,15 +53,15 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t N0 = 1;
|
||||
constexpr index_t Ho0 = 1;
|
||||
constexpr index_t Wo0 = 2;
|
||||
constexpr index_t N0 = 1;
|
||||
constexpr index_t Ho0 = 2;
|
||||
constexpr index_t Wo0 = 1;
|
||||
|
||||
constexpr index_t N2 = 1;
|
||||
constexpr index_t N2 = 4;
|
||||
constexpr index_t Ho2 = 1;
|
||||
constexpr index_t Wo2 = 4;
|
||||
constexpr index_t Wo2 = 1;
|
||||
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -79,8 +79,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 1, 2, 16, 1, 1, 1>;
|
||||
using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 4, 1, 1>;
|
||||
using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 2, 1, 16, 1, 1, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder =
|
||||
Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2]
|
||||
using InBlockCopySrcAccessOrder =
|
||||
@@ -88,7 +88,7 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
|
||||
using InBlockCopyDstAccessOrder =
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2]
|
||||
|
||||
constexpr index_t InBlockCopyDataPerAccess_W2 = 4;
|
||||
constexpr index_t InBlockCopyDataPerAccess_W2 = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
|
||||
Reference in New Issue
Block a user