diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp index e19bc5093e..728b990418 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp @@ -2,16 +2,27 @@ #define CK_BLOCKWISE_GEMM_V2_HPP #include "common_header.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_gemm_v2.hpp" namespace ck { -// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N] +// C[M, N] += transpose(A[K, M]) * B[K, N] // A and B are visable to the whole block, C is distributed among each thread -// If following number are power of 2, index calculation shall be greatly reduced: -// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster, -// MLevel1ThreadCluster, NLevel1ThreadCluster +// Assume: +// 1. A: +// 1. BlockMatrixA is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BlockMatrixA is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. ThreadMatrixC is known at compile-time +// 2. CThreadBuffer is StaticBuffer template -struct BlockwiseGemm_km_kn_m0m1n0n1_v1 + index_t ThreadGemmBDataPerRead_N, + typename std::enable_if::type = false> +struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1 { struct MatrixIndex { @@ -32,10 +47,49 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 index_t col; }; - index_t mMyThreadOffsetA; - index_t mMyThreadOffsetB; + private: + static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, ThreadMatrixC{}.GetLength(Number<0>{}))); - __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1() + static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, ThreadMatrixC{}.GetLength(Number<1>{}))); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmADataPerRead_M, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + using BThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmBDataPerRead_N, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; + + MatrixIndex c_thread_begin_mtx_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; + + public: + __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1() + : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)}, + b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)} { static_assert(BlockMatrixA::IsKnownAtCompileTime() && BlockMatrixB::IsKnownAtCompileTime() && @@ -51,23 +105,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), - "wrong! K dimension not consistent\n"); + "wrong! K dimension not consistent"); constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed constexpr index_t N = BlockMatrixB{}.GetLength(I1); static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0, - "wrong! Cannot evenly divide work among\n"); + "wrong! Cannot evenly divide work among"); static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] && ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1], "wrong! 
ThreadMatrixC lengths is wrong"); - - auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row)); - mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col)); } __device__ static constexpr auto GetThreadMatrixCLengths() @@ -104,103 +153,30 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; } - template - __device__ void - Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const + template + __device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const { + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx_desc = ThreadMatrixC{}; constexpr auto K = a_block_mtx.GetLength(I0); - constexpr auto MPerThread = c_thread_mtx.GetLength(I0); - constexpr auto NPerThread = c_thread_mtx.GetLength(I1); - - constexpr index_t MPerLevel1Cluster = - MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; - constexpr index_t NPerLevel1Cluster = - NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - // thread A, B for GEMM - constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( - Number{}, Number{}); - - constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( - Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; - - constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; - - constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1{}; -#pragma unroll - // loop over k - for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) - { -#pragma unroll - // read A - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - a_thread_copy.Run(p_a_block + - a_block_mtx.CalculateOffset( - make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) + - mMyThreadOffsetA, - p_a_thread + a_thread_mtx.CalculateOffset( - make_tuple(0, m_repeat * MPerThreadSubC))); - } - -#pragma unroll - // read B - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - b_thread_copy.Run(p_b_block + - b_block_mtx.CalculateOffset( - make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) + - mMyThreadOffsetB, - p_b_thread + b_thread_mtx.CalculateOffset( - make_tuple(0, n_repeat * NPerThreadSubC))); - } - - // C += A * B - threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); - } - } - - template - __device__ void - Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr auto K = a_block_mtx.GetLength(I0); 
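// [editor's annotation, not part of the patch] The rewritten Run_pipelined_2x2 in the
// hunks that follow keeps the original software-pipelined schedule, only the data now
// lives in StaticBuffer (VGPR) / DynamicBuffer (LDS) objects addressed through
// compile-time descriptors: the prologue reads A_sub_0, B_sub_0, B_sub_1, A_sub_1 of
// the first k-slice and computes C_sub_00 and C_sub_01; each loop iteration then
// overlaps the loads of the next k-slice with C_sub_10/C_sub_11 of the current slice
// and C_sub_00/C_sub_01 of the next slice; the epilogue finishes C_sub_10 and C_sub_11.
// Below is a minimal scalar model of that ordering; names and the one-element
// "sub-tiles" are purely illustrative, the kernel works on
// MPerThreadSubC x NPerThreadSubC register tiles.

// C[2][2] += transpose(A[K][2]) * B[K][2], with the loads of the next k-slice
// interleaved with the FMAs of the current one, in the same order as the kernel.
template <int K>
void gemm_2x2_pipelined(const float (&A)[K][2], const float (&B)[K][2], float (&C)[2][2])
{
    // prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1 of k-slice 0
    float a0 = A[0][0], b0 = B[0][0], b1 = B[0][1], a1 = A[0][1];

    // C_sub_00 and C_sub_01 of k-slice 0
    C[0][0] += a0 * b0;
    C[0][1] += a0 * b1;

    for(int k = 1; k < K; ++k)
    {
        float a0_next = A[k][0]; // read A_sub_0 of the next slice
        C[1][0] += a1 * b0;      // C_sub_10 of the current slice
        float b0_next = B[k][0]; // read B_sub_0 of the next slice
        C[1][1] += a1 * b1;      // C_sub_11 of the current slice
        float b1_next = B[k][1]; // read B_sub_1 of the next slice
        float a1_next = A[k][1]; // read A_sub_1 of the next slice

        // C_sub_00 and C_sub_01 of the next slice
        C[0][0] += a0_next * b0_next;
        C[0][1] += a0_next * b1_next;

        a0 = a0_next; a1 = a1_next; b0 = b0_next; b1 = b1_next;
    }

    // epilogue: C_sub_10 and C_sub_11 of the last slice
    C[1][0] += a1 * b0;
    C[1][1] += a1 * b1;
}

// usage, e.g.: float A[8][2], B[8][2], C[2][2] = {}; gemm_2x2_pipelined(A, B, C);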
- - constexpr auto MPerThread = c_thread_mtx.GetLength(I0); - constexpr auto NPerThread = c_thread_mtx.GetLength(I1); + constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0); + constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1); constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; @@ -211,15 +187,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - static_assert(MRepeat == 2 && NRepeat == 2, - "wrong! inline asm cannot deal with this GEMM config yet"); - - // thread A, B - constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); - - constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline"); // thread A-sub, B-sub constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( @@ -234,113 +202,152 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 make_tuple(Number{}, Number{}), make_tuple(Number{}, Number<1>{})); - FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()]; + auto a_thread_buf = make_static_buffer(a_thread_mtx_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer(b_thread_mtx_desc_.GetElementSpaceSize()); - constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; - - constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; - - constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1{}; - - const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA; - const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB; + constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1{}; // read A_sub_0 - a_thread_copy.Run(p_a_block_off, p_a_thread); + a_thread_copy_.Run(BlockMatrixA{}, + make_tuple(I0, I0), + a_block_buf, + a_thread_mtx_desc_, + make_tuple(I0, I0), + a_thread_buf); // read B_sub_0 - b_thread_copy.Run(p_b_block_off, p_b_thread); + b_thread_copy_.Run(BlockMatrixB{}, + make_tuple(I0, I0), + b_block_buf, + b_thread_mtx_desc_, + make_tuple(I0, I0), + b_thread_buf); // read B_sub_1 - b_thread_copy.Run(p_b_block_off + - b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)), - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + b_thread_copy_.Run(BlockMatrixB{}, + make_tuple(I0, Number{}), + b_block_buf, + b_thread_mtx_desc_, + make_tuple(I0, Number{}), + b_thread_buf); // read A_sub_1 - a_thread_copy.Run(p_a_block_off + - a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)), - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC))); + a_thread_copy_.Run(BlockMatrixA{}, + make_tuple(I0, Number{}), + a_block_buf, + a_thread_mtx_desc_, + make_tuple(I0, Number{}), + a_thread_buf); // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(I0, I0), + c_thread_buf, + make_tuple(I0, I0)); // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run( - p_a_thread, - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(I0, Number{}), + c_thread_buf, + make_tuple(I0, 
Number{})); -#pragma unroll // loop over rest of k - for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop) - { + static_for{}([&](auto k) { // read A_sub_0 - a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), - p_a_thread); + a_thread_copy_.Run(BlockMatrixA{}, + make_tuple(k, I0), + a_block_buf, + a_thread_mtx_desc_, + make_tuple(I0, I0), + a_thread_buf); // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run( - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), - p_b_thread, - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0))); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, Number{}), + b_thread_buf, + make_tuple(I0, I0), + c_thread_buf, + make_tuple(Number{}, I0)); // read B_sub_0 - b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)), - p_b_thread); + b_thread_copy_.Run(BlockMatrixB{}, + make_tuple(k, I0), + b_block_buf, + b_thread_mtx_desc_, + make_tuple(I0, I0), + b_thread_buf); // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run( - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), - p_c_thread + - c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC))); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, Number{}), + b_thread_buf, + make_tuple(I0, Number{}), + c_thread_buf, + make_tuple(Number{}, Number{})); // read B_sub_1 - b_thread_copy.Run( - p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)), - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + b_thread_copy_.Run(BlockMatrixB{}, + make_tuple(k, Number{}), + b_block_buf, + b_thread_mtx_desc_, + make_tuple(I0, Number{}), + b_thread_buf); // read A_sub_1 - a_thread_copy.Run( - p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)), - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC))); + a_thread_copy_.Run(BlockMatrixA{}, + make_tuple(k, Number{}), + a_block_buf, + a_thread_mtx_desc_, + make_tuple(I0, Number{}), + a_thread_buf); // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(I0, I0), + c_thread_buf, + make_tuple(I0, I0)); // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run( - p_a_thread, - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); - } + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(I0, Number{}), + c_thread_buf, + make_tuple(I0, Number{})); + }); // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run( - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), - p_b_thread, - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0))); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, Number{}), + b_thread_buf, + make_tuple(I0, I0), + c_thread_buf, + make_tuple(Number{}, I0)); // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run( - p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC))); + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, 
Number{}), + b_thread_buf, + make_tuple(I0, Number{}), + c_thread_buf, + make_tuple(Number{}, Number{})); } - template - __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const { #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE constexpr auto I0 = Number<0>{}; @@ -354,17 +361,16 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1 if constexpr(MRepeat == 2 && NRepeat == 2) { - Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread); + Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf); } else { - Run_naive(p_a_block, p_b_block, p_c_thread); + Run_naive(a_block_buf, b_block_buf, c_thread_buf); } #else - Run_naive(p_a_block, p_b_block, p_c_thread); + Run_naive(a_block_buf, b_block_buf, c_thread_buf); #endif } }; - } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp index 7da08d6ef4..0048e396f7 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp @@ -6,12 +6,10 @@ namespace ck { -// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N] -// A and B are visable to the whole block, C is distributed among each thread -// If following number are power of 2, index calculation shall be greatly reduced: -// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster, -// MLevel1ThreadCluster, NLevel1ThreadCluster template {}, Number{})); + + static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmADataPerRead_K, + AddressSpace::Generic, + AddressSpace::Vgpr, + 1>; __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3() + : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} { static_assert(BlockMatrixA::IsKnownAtCompileTime() && BlockMatrixB::IsKnownAtCompileTime() && @@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n"); - - auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = - BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread)); } __device__ static constexpr auto GetThreadMatrixCLengths() @@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; } - template - struct ThreadwiseSliceCopy_a - { - template - __device__ static void Run(const Data* p_src, Data* p_dst) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! 
Desc should be known at compile-time"); - - using vector_t = typename vector_type_maker::type::type; - - static_for<0, NSliceRow, 1>{}([&](auto i) { - static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { - constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j)); - constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j)); - - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); - }); - }); - } - }; - - template - __device__ void - Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const { + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 constexpr auto EPerBlock = a_block_mtx.GetLength(I0); - constexpr auto KPerThreadSubC = 4; - + // HACK: fix this @Jing Zhang constexpr auto HoPerThreadSubC = 2; constexpr auto WoPerThreadSubC = 2; @@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 static_assert(HPerThread % HoPerThreadSubC == 0, ""); static_assert(WPerThread % WoPerThreadSubC == 0, ""); - // thread A, B for GEMM - constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + // thread A buffer for GEMM + StaticBuffer a_thread_buf; - constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; - - constexpr auto a_thread_copy = ThreadwiseSliceCopy_a{}; - - constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3{}; - // loop over k -#pragma unroll - for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop) - { -#pragma unroll - for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC) - { - a_thread_copy.Run(p_a_block + - a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) + - mMyThreadOffsetA, - p_a_thread); -#pragma unroll - for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC) - { -#pragma unroll - for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC) - { - threadwise_gemm.Run(p_a_thread, - p_b_thread + b_thread_mtx.CalculateOffset(make_tuple( - e_begin, 0, h_begin, w_begin)), - p_c_thread + c_thread_mtx.CalculateOffset(make_tuple( - k_begin, 0, h_begin, w_begin))); - } - } - } - } + static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { + + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0), + a_thread_buf); + + static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { + static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, h_begin, w_begin), + c_thread_buf, + make_tuple(k_begin, I0, h_begin, w_begin)); + }); + }); + }); + }); } - template - __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const + template + __device__ void 
MoveASliceWindow(const BlockMatrixA&, + const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) { - Run_naive(p_a_block, p_b_thread, p_c_thread); + a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); } + + private: + MatrixIndex c_thread_begin_mtx_idx_; + + AThreadCopy a_thread_copy_; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp index 15df1d23f4..0f94f67bbc 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp @@ -5,9 +5,10 @@ #include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_v2.hpp" #include "blockwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "blockwise_gemm_v2.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" namespace ck { @@ -256,19 +257,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 make_tuple(Number{}, Number{})); const auto blockwise_gemm = - BlockwiseGemm_km_kn_m0m1n0n1_v1{}; + BlockwiseGemm_km_kn_m0m1n0n1_v1r1{}; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size; // register allocation for output - FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()]; + auto c_thread_buf = + make_static_buffer(c_m0m1_n0n1_thread_desc.GetElementSpaceSize()); - // zero out threadwise output - threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread); + ThreadwiseDynamicTensorSliceSet_v1>{} + .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0}); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); @@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 constexpr auto b_k_n_global_move_slice_window_iterator_hack = BGlobalMoveSliceWindowIteratorHacks{}; + FloatAB* p_a_block_even = p_a_block_double; + FloatAB* p_b_block_even = p_b_block_double; + + FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size; + FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size; + + auto a_block_even_buf = make_dynamic_buffer(p_a_block_even); + auto b_block_even_buf = make_dynamic_buffer(p_b_block_even); + + auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd); + auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd); + // LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); @@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 if constexpr(HasMainKBlockLoop) { - FloatAB* p_a_block_even = p_a_block_double; - FloatAB* p_b_block_even = p_b_block_double; - - FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size; - FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size; - index_t k_block_data_begin = 0; // LDS double buffer: main body @@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); + blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); 
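// [editor's annotation, not part of the patch] The call above is the "even" half of the
// LDS double buffer: the blockwise GEMM consumes the buffer that was written in the
// previous step while a_blockwise_copy / b_blockwise_copy fetch the next K-slice from
// global memory into the other half; the "odd" half just below mirrors it with the
// buffer roles swapped, and a one- or two-slice tail follows the loop. A minimal
// host-side model of that ping-pong schedule, simplified to one slice per iteration;
// load/compute are stand-ins for RunRead/RunWrite and blockwise_gemm.Run, and the
// __syncthreads() barriers are omitted. All names here are illustrative.
#include <cstddef>
#include <vector>

template <typename LoadFn, typename ComputeFn>
void double_buffered_loop(std::size_t num_chunks,
                          std::size_t chunk_size,
                          LoadFn load,
                          ComputeFn compute)
{
    if(num_chunks == 0)
        return;

    std::vector<float> buf_even(chunk_size), buf_odd(chunk_size);

    load(0, buf_even); // preload the first chunk into the "even" buffer

    for(std::size_t i = 0; i + 1 < num_chunks; ++i)
    {
        auto& cur  = (i % 2 == 0) ? buf_even : buf_odd;
        auto& next = (i % 2 == 0) ? buf_odd : buf_even;

        load(i + 1, next); // prefetch chunk i+1 while...
        compute(cur);      // ...chunk i is consumed by the GEMM
    }

    // tail: the last chunk ended up in the even or odd buffer depending on parity
    compute((num_chunks % 2 == 1) ? buf_even : buf_odd);
}

// The kernel unrolls two slices per loop iteration instead, so each LDS half keeps a
// fixed role inside the loop body and no pointer swap is needed at run time.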
// LDS double buffer: store next data to LDS a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); @@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); + blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); // LDS double buffer: store next data to LDS a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); @@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); + blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store last data to LDS a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); @@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 __syncthreads(); // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double + a_block_space_size, - p_b_block_double + b_block_space_size, - p_c_thread); + blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); } else // if has 1 iteration left { __syncthreads(); // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); + blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); } // output: register to global memory @@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 n_thread_data_on_global % N1)) .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), - p_c_thread, + c_thread_buf, c_m0_m1_n0_n1_global_desc, p_c_global, c_m0_m1_n0_n1_global_tensor_iterator_hacks); diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp index 81a3a0674f..3be66f61d3 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp @@ -145,17 +145,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( Number{}, Number<1>{}, Number{}, Number{})); - const auto blockwise_gemm = - BlockwiseGemm_km_kn_m0m1n0n1_v3{}; + auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3{}; auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); @@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 FloatAB* p_a_block = p_shared_block; - // register allocation for output - FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; + auto a_block_buf = make_dynamic_buffer(p_a_block); - // zero out threadwise output - threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread); + // register allocation for output + StaticBuffer c_thread_buf; + + // initialize output thread tensor + ThreadwiseDynamicTensorSliceSet_v1>{} + .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); @@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = BGlobalMoveSliceWindowIteratorHacks{}; - constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize(); - FloatAB p_b_thread[b_thread_space_size * 2]; + // double 
regsiter buffer for b + StaticBuffer b_thread_even_buf, + b_thread_odd_buf; - FloatAB* p_b_thread_double = p_b_thread; - - // LDS double buffer: preload data into LDS + // LDS double buffer: preload data { a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks); @@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 p_b_global, b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), - p_b_thread_double, + b_thread_even_buf, b_e_n_ho_wo_global_iterator_hacks); a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block); @@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 __syncthreads(); - index_t b_block_data_begin = 0; - -#if 1 if constexpr(HasMainKBlockLoop) { - FloatAB* p_b_thread_even = p_b_thread_double; - FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size; + index_t e_block_data_begin = 0; // LDS double buffer: main body // use Do-While loop instead of For loop to simplify control flow @@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 p_b_global, b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), - p_b_thread_odd, + b_thread_odd_buf, b_e_n_ho_wo_global_iterator_hacks); // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), - p_b_thread_even, - p_c_thread); + // TODO: @Zhang Jing: blockwise gemm should be able to move slice window + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - b_block_data_begin += EPerBlock; + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, b_thread_slice_copy_step); @@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 p_b_global, b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), - p_b_thread_even, + b_thread_even_buf, b_e_n_ho_wo_global_iterator_hacks); // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), - p_b_thread_odd, - p_c_thread); + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - b_block_data_begin += EPerBlock; + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - } while(b_block_data_begin < E - 2 * EPerBlock); + e_block_data_begin += 2 * EPerBlock; + + } while(e_block_data_begin < E - 2 * EPerBlock); } // LDS double buffer: tail @@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 p_b_global, b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), - p_b_thread_double + b_thread_space_size, + b_thread_odd_buf, b_e_n_ho_wo_global_iterator_hacks); // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), - p_b_thread_double, - p_c_thread); + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - b_block_data_begin += EPerBlock; + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), - p_b_thread_double + b_thread_space_size, - p_c_thread); + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); } else // if has 1 iteration left { // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), - p_b_thread_double, - p_c_thread); + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, 
c_thread_buf); } -#endif -#if 1 // output: register to global memory { // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor @@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), - p_c_thread, + c_thread_buf, c_k_n_ho_wo_global_desc, p_c_global, c_k_n_ho_wo_global_tensor_iterator_hacks); } -#endif } // pass tensor descriptor by reference diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp new file mode 100644 index 0000000000..f1b632aa84 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp @@ -0,0 +1,59 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. Desc is known at compile-time +// 2. Buffer is StaticBuffer +// 3. OriginIdx is known at compile-time +// 4. use #-iterator +template ::type = false> +struct ThreadwiseDynamicTensorSliceSet_v1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + template + __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const + { + static_assert(Desc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>>::value, + "wrong! OriginIdx need to be known at compile-time"); + + // Desc is known at compile-time + constexpr auto desc = remove_cv_t>{}; + + // OriginIdx is known at compile-time + constexpr auto origin_idx = to_multi_index(OriginIdx{}); + + static_ford{}([&](auto access_idx) { + constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx); + + constexpr bool is_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); + + constexpr index_t offset = coord.GetOffset(); + + if constexpr(is_valid) + { + buf(Number{}) = initial_value; + } + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp index 4f9ecd8b54..34b6cfec79 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -7,6 +7,15 @@ namespace ck { +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { // TODO: How to fix this? It uses an struct instead of lambda because lambda // doesn't have constructor template @@ -26,12 +35,17 @@ struct lambda_scalar_step_in_vector return (i == VectorDim) ? 
1 : 0; } }; +} // namespace detail -// this version is less likely to have scratch memory issue, due to: -// 1. It does not keep reference to tensor descriptor -// 2. It does not construct new tensor coordinate for this->Run() -// Assume src_slice_origin_idx is 0 -// TODO: support non-zero src_slice_oring_idx +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time template + template __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, - const SrcData* p_src, + const SrcBuffer& src_buf, const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks) @@ -84,9 +98,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 is_known_at_compile_time>>::value, "wrong! SrcSliceOrigin need to known at compile-time"); + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer data type is wrong"); + // SrcDesc and src_slice_origin_idx are known at compile-time constexpr auto src_desc = remove_cv_t>{}; - constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -94,10 +114,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_scalar_step_in_vector = - generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -178,12 +198,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 typename vector_type_maker::type::type; static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - constexpr index_t src_offset = - src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + - i * dst_scalar_step_in_vector); + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); dst_vector.template AsType()(i) = - type_convert{}(p_src[Number{}]); + type_convert{}(src_buf[Number{}]); }); const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( @@ -284,7 +303,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -359,10 +378,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 DstCoord dst_slice_origin_coord_; }; // namespace ck -// this version is less likely to have scratch memory issue, due to: -// 1. It does not keep reference to tensor descriptor -// 2. It does not construct new tensor coordinate for this->Run() -// Assume dst_slice_origin_idx is 0 +// Assume: +// 1. src_desc is not known at compile-time +// 2. dst_desc is known at compile-time +// 3. src_slice_origin_idx is not known at compile-time +// 4. 
dst_slice_origin_idx is known at compile-time and it's 0 template + template __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, const DstDesc&, const DstSliceOriginIdx&, - DstData* p_dst, + DstBuffer& dst_buf, const SrcIteratorHacks& src_iterator_hacks) { static_assert(DstDesc::IsKnownAtCompileTime(), @@ -414,6 +434,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 is_known_at_compile_time>>::value, "wrong! DstSliceOrigin need to known at compile-time"); + static_assert(is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); + // DstDesc and dst_slice_origin_idx are known at compile-time constexpr auto dst_desc = remove_cv_t>{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; @@ -424,10 +448,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_scalar_step_in_vector = - generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; @@ -541,7 +565,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + i * src_scalar_step_in_vector); - p_dst[Number{}] = src_vector.template AsType()[i]; + dst_buf(Number{}) = src_vector.template AsType()[i]; }); constexpr auto move_on_dim = [&]() constexpr @@ -590,7 +614,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 } } - __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst) + template + __device__ void Run(const SrcDesc& src_desc, + const SrcData* p_src, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) { constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); @@ -600,7 +629,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - Run(src_desc, p_src, p_dst, src_iterator_hacks); + Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -610,7 +639,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; @@ -685,12 +714,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 SrcCoord src_slice_origin_coord_; }; // namespace ck -// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory -// and sometimes useless instructions -// 1. It does not keep reference to tensor descriptor -// 2. It does not construct new tensor coordinate for this->Run() -// 3. It does not use pointer for VGPR thread buffer -// 4. It calculate offset for thread buffer directly, instead of moving the coordinate +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. src_slice_origin and dst_slice_origin are not known at compile-time, +// 3. Use thread buffer template ::value, + "wrong! 
current implementation assume SrcData and DstData are same type"); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -760,10 +791,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_scalar_step_in_vector = - generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -838,11 +869,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 return src_data_idx; }(); - // copy data - typename vector_type_maker::type src_vector; + // copy data from src_buf to src_tmp_vector + vector_type_maker_t src_tmp_vector; - using src_vector_t = - typename vector_type_maker::type::type; + using src_vector_t = typename decltype(src_tmp_vector)::type; const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, src_slice_origin_coord_); @@ -850,14 +880,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 if constexpr(SrcAddressSpace == AddressSpace::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING - src_vector.template AsType()(Number<0>{}) = + src_tmp_vector.template AsType()(Number<0>{}) = amd_buffer_load_v2( p_src, src_slice_origin_coord_.GetOffset(), is_src_valid, src_desc.GetElementSpaceSize()); #else - src_vector.template AsType()(Number<0>{}) = + src_tmp_vector.template AsType()(Number<0>{}) = is_src_valid ? *reinterpret_cast( &p_src[src_slice_origin_coord_.GetOffset()]) : src_vector_t{0}; @@ -865,17 +895,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 } else { - src_vector.template AsType()(Number<0>{}) = + src_tmp_vector.template AsType()(Number<0>{}) = is_src_valid ? *reinterpret_cast( &p_src[src_slice_origin_coord_.GetOffset()]) : src_vector_t{0}; } + // copy data from src_tmp_vector to buffer_ static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { constexpr index_t buffer_offset = buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); - buffer_(Number{}) = src_vector.template AsType()[i]; + buffer_(Number{}) = src_tmp_vector.template AsType()[i]; }); constexpr auto move_on_dim = [&]() constexpr @@ -937,10 +968,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // src scalar per access on each dim // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_scalar_step_in_vector = - generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -1026,20 +1057,21 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 DstInMemOp == InMemoryDataOperation::Set, "wrong! 
hardcoded for ds_write"); - typename vector_type_maker::type dst_vector; + vector_type_maker_t dst_tmp_vector; + // copy data from buffer_ to dst_tmp_vector static_for<0, DstScalarPerVector, 1>{}([&](auto i) { constexpr index_t buffer_offset = buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); - dst_vector.template AsType()(i) = buffer_[Number{}]; + dst_tmp_vector.template AsType()(i) = buffer_[Number{}]; }); - using DstVectorType = - typename vector_type_maker::type::type; + using dst_vector_t = typename decltype(dst_tmp_vector)::type; - *reinterpret_cast(p_dst + dst_slice_origin_coord_.GetOffset()) = - dst_vector.template AsType()[Number<0>{}]; + // copy data from dst_tmp_vector to dst_buf + *reinterpret_cast(p_dst + dst_slice_origin_coord_.GetOffset()) = + dst_tmp_vector.template AsType()[Number<0>{}]; constexpr auto move_on_dim = [&]() constexpr { @@ -1123,7 +1155,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; @@ -1185,7 +1217,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access constexpr auto dst_scalar_per_access = generate_sequence( - lambda_scalar_per_access{}, Number{}); + detail::lambda_scalar_per_access{}, Number{}); constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; @@ -1274,7 +1306,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) @@ -1297,11 +1328,203 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticallyIndexedArray buffer_; + StaticBuffer buffer_; SrcCoord src_slice_origin_coord_; DstCoord dst_slice_origin_coord_; }; +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-iterator +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. 
vector access on src +template < + typename SrcData, + typename DstData, + typename SrcDesc, + typename DstDesc, + typename SliceLengths, + typename DimAccessOrder, + index_t SrcVectorDim, + index_t SrcScalarPerVector, + AddressSpace SrcAddressSpace, + AddressSpace DstAddressSpace, + index_t SrcScalarStrideInVector, + typename std::enable_if::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + 
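// [editor's annotation, not part of the patch] ThreadwiseDynamicTensorSliceTransfer_v4
// keeps a single run-time coordinate (src_ref_coord_, set from the thread id in the
// blockwise GEMM constructor) and reaches every element of the slice by adding a
// displacement that is fully known at compile time, as computed just above; no
// per-access coordinate has to be kept alive and the displacement arithmetic can fold
// into constants. A stripped-down model of that split between a run-time reference
// offset and compile-time displacements, using a 2-D row-major layout; all names are
// illustrative.
#include <cstddef>

// compile-time 2-D row-major layout, standing in for a descriptor known at compile time
template <std::size_t Stride0>
struct StaticRowMajor2D
{
    static constexpr std::size_t offset(std::size_t i0, std::size_t i1)
    {
        return i0 * Stride0 + i1;
    }
};

// run-time reference offset (one per thread) plus a compile-time displacement
template <typename Desc, std::size_t Disp0, std::size_t Disp1>
std::size_t access_offset(std::size_t ref_offset)
{
    // Desc::offset(Disp0, Disp1) is a compile-time constant, so only a single
    // run-time add survives per access
    return ref_offset + Desc::offset(Disp0, Disp1);
}

// e.g. access_offset<StaticRowMajor2D<128>, 0, 4>(ref) yields ref + 4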
constexpr auto src_ref_to_data_disp_coord_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_dynamic_tensor_coordinate( + src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + + // copy data from src_buf into src_tmp_buffer + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + src_tmp_vector.template AsType()(Number<0>{}) = + is_src_valid ? src_buf.template Get(src_data_coord.GetOffset()) + : src_vector_t{0}; + + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert{}(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( + src_desc, to_multi_index(src_slice_move_step_idx)); + + move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp index 868f205630..0e69bdbc38 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp @@ -6,100 +6,52 @@ namespace ck { -template -__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread) -{ - static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - constexpr auto desc = Desc{}; - - constexpr auto M = desc.GetLength(I0); - constexpr auto N = desc.GetLength(I1); - - static_for<0, M, 1>{}([&](auto i) { - static_for<0, N, 1>{}([&](auto j) { - constexpr auto offset = desc.CalculateOffset(make_tuple(i, j)); - - p_thread[offset] = Float(0); - }); - }); -} - -template -struct ThreadwiseMatrixSliceCopy_v2 -{ - template - __device__ static void Run(const Data* p_src, Data* p_dst) - { - static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), - "wrong! 
Desc should be known at compile-time"); - - using vector_t = typename vector_type_maker::type::type; - - static_for<0, NSliceRow, 1>{}([&](auto i) { - static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { - constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j)); - constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j)); - - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); - }); - }); - } -}; - // C[M, N] += transpose(A[K, M]) * B[K, N] // Element of matrix can be vectorized data -template ::type = false> -struct ThreadwiseGemm_km_kn_mn_v1 +struct ThreadwiseGemm_km_kn_mn_v1r1 { - template - __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) { static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - constexpr auto M = CDesc{}.GetLength(I0); - constexpr auto N = CDesc{}.GetLength(I1); - constexpr auto K = ADesc{}.GetLength(I0); - - static_for<0, K, 1>{}([&](auto k) { - static_for<0, M, 1>{}([&](auto m) { - static_for<0, N, 1>{}([&](auto n) { - constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); - constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n)); - constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n)); - - p_c[c_offset] += - inner_product_with_conversion{}(p_a[a_offset], p_b[b_offset]); - }); - }); - }); - } - -#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM - template - __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { - static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && - CDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -110,62 +62,82 @@ struct ThreadwiseGemm_km_kn_mn_v1 constexpr auto N = CDesc{}.GetLength(I1); constexpr auto K = ADesc{}.GetLength(I0); - static_assert(N == 4 || N == 2, "wrong! 
this config not supported by asm yet"); + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); static_for<0, K, 1>{}([&](auto k) { static_for<0, M, 1>{}([&](auto m) { - constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); + constexpr index_t a_offset = + ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m)); +#if 0 if constexpr(N == 2) { - constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); - constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1)); - constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); - constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1)); + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1)); - amd_assembly_outer_product_1x2(p_a[a_offset], - p_b[b_offset_0], - p_b[b_offset_1], - p_c[c_offset_0], - p_c[c_offset_1]); + amd_assembly_outer_product_1x2(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{})); } else if constexpr(N == 4) { - constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); - constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); - constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2)); - constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3)); + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3)); - constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); - constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1)); - constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2)); - constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3)); + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3)); - amd_assembly_outer_product_1x4(p_a[a_offset], - p_b[b_offset_0], - p_b[b_offset_1], - p_b[b_offset_2], - p_b[b_offset_3], - p_c[c_offset_0], - p_c[c_offset_1], - p_c[c_offset_2], - p_c[c_offset_3]); + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + } + else +#endif + { + static_for<0, N, 1>{}([&](auto n) { + + constexpr index_t b_offset = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n)); + constexpr index_t c_offset = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n)); + + amd_assembly_inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); } }); }); } -#endif - - 
template - __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { -#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM - Run_amd_asm(p_a, p_b, p_c); -#else - Run_source(p_a, p_b, p_c); -#endif - } }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp index 54a4932f4d..8c78448e80 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp @@ -6,35 +6,15 @@ namespace ck { -template -__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread) -{ - static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto desc = Desc{}; - - constexpr auto K = desc.GetLength(I0); - constexpr auto H = desc.GetLength(I2); - constexpr auto W = desc.GetLength(I3); - - static_for<0, K, 1>{}([&](auto i) { - static_for<0, H, 1>{}([&](auto j) { - static_for<0, W, 1>{}([&](auto k) { - constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k)); - - p_thread[offset] = Float(0); - }); - }); - }); -} - // C[M, N] += transpose(A[K, M]) * B[K, N] // Element of matrix can be vectorized data -template ::type = false> struct ThreadwiseGemm_km_kn_mn_v3 { - template - __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) { static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3 constexpr auto E = ADesc{}.GetLength(I0); constexpr auto K = ADesc{}.GetLength(I1); + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + static_for<0, E, 1>{}([&](auto e) { static_for<0, K, 1>{}([&](auto k) { - constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k)); + constexpr index_t a_offset = + ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); if constexpr(H == 2 && W == 2) { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1)); - constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); - constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1)); - constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); - constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1)); + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1)); - constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); - constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1)); - constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0)); - constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1)); - - amd_assembly_outer_product_1x4(p_a[a_offset], - p_b[b_offset_0], - p_b[b_offset_1], - p_b[b_offset_2], - p_b[b_offset_3], - p_c[c_offset_0], - p_c[c_offset_1], - p_c[c_offset_2], - p_c[c_offset_3]); + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); } else if constexpr(H == 4 && W == 1) { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0)); - constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); - constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); - constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0)); - constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0)); + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_2 = + 
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0)); - constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); - constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0)); - constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0)); - constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0)); - - amd_assembly_outer_product_1x4(p_a[a_offset], - p_b[b_offset_0], - p_b[b_offset_1], - p_b[b_offset_2], - p_b[b_offset_3], - p_c[c_offset_0], - p_c[c_offset_1], - p_c[c_offset_2], - p_c[c_offset_3]); + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); } else { static_for<0, H, 1>{}([&](auto h) { static_for<0, W, 1>{}([&](auto w) { - constexpr auto b_offset = - BDesc{}.CalculateOffset(make_tuple(e, 0, h, w)); - constexpr auto c_offset = - CDesc{}.CalculateOffset(make_tuple(k, 0, h, w)); - p_c[c_offset] += inner_product_with_conversion{}(p_a[a_offset], - p_b[b_offset]); + constexpr index_t b_offset = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w)); + + constexpr index_t c_offset = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); + +#if 0 + c_buf(Number{}) += inner_product_with_conversion{}( + a_buf[Number{}], b_buf[Number{}]); +#else + amd_assembly_inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); +#endif }); }); } }); }); } - - template - __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { - Run_source(p_a, p_b, p_c); - } }; } // namespace ck diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index 3c8b58193b..b5d2e4e38e 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -5,6 +5,75 @@ namespace ck { +// c += inner_product(a, b) +__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_V_FMAC_F32 + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + asm volatile("\n \ + v_mac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#endif +} + +__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); +#endif +} + +__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_assembly_inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_assembly_inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_assembly_inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_assembly_inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + 
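+    // the two remaining int8x4 sub-vectors are accumulated into the same int32
+    // accumulator, i.e. the 16-byte dot product is evaluated as four chained
+    // 4-byte (v_dot4_i32_i8) inner products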
amd_assembly_inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_assembly_inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) diff --git a/composable_kernel/include/utility/buffer.hpp b/composable_kernel/include/utility/buffer.hpp new file mode 100644 index 0000000000..fbd789b6fd --- /dev/null +++ b/composable_kernel/include/utility/buffer.hpp @@ -0,0 +1,72 @@ +#ifndef CK_BUFFER_HPP +#define CK_BUFFER_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +template +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + + __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {} + + __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ constexpr const auto Get(index_t i) const + { + return *reinterpret_cast(&p_data_[i]); + } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, const X& x) + { + *reinterpret_cast(&p_data_[i]) = x; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p) +{ + return DynamicBuffer{p}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 5a26f8958f..6afe465800 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -7,6 +7,7 @@ #include "statically_indexed_array.hpp" #include "container_element_picker.hpp" #include "float_type.hpp" +#include "buffer.hpp" #include "functional.hpp" #include "functional2.hpp" #include "functional3.hpp" diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 9de35587fd..e42b8e5bef 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -14,11 +14,11 @@ #define CK_DEVICE_BACKEND_AMD 1 // GPU ID -#if 1 +#if 0 #define CK_AMD_GPU_GFX906 1 #elif 0 #define CK_AMD_GPU_GFX908 1 -#elif 0 +#elif 1 #define CK_AMD_GPU_GFX1030 1 #endif @@ -28,7 +28,7 @@ #endif // launch bounds -#define CK_USE_LAUNCH_BOUNDS 0 +#define CK_USE_LAUNCH_BOUNDS 1 #ifdef CK_USE_LAUNCH_BOUNDS #define CK_MAX_THREAD_PER_BLOCK 256 diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in index f957f9aaa7..44cf657cb1 100644 --- a/composable_kernel/include/utility/float_type.amd.hpp.in +++ 
b/composable_kernel/include/utility/float_type.amd.hpp.in @@ -1,6 +1,8 @@ #ifndef CK_FLOAT_TYPE_AMD_HPP #define CK_FLOAT_TYPE_AMD_HPP +#include "statically_indexed_array.hpp" + namespace ck { using half_t = _Float16; @@ -43,6 +45,15 @@ struct vector_type_maker, N0> using type = vector_type; }; +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + // scalar_type template struct scalar_type; @@ -403,32 +414,249 @@ struct vector_type } }; +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + 
is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } +}; + // fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; // fp16 using half2_t = typename vector_type::type; using half4_t = typename vector_type::type; using half8_t = typename vector_type::type; using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; // bfp16 -using ushort2_t = typename vector_type::type; -using ushort4_t = typename vector_type::type; -using ushort8_t = typename vector_type::type; +using ushort2_t = typename vector_type::type; +using ushort4_t = typename vector_type::type; +using ushort8_t = typename vector_type::type; +using ushort16_t = typename vector_type::type; +using ushort32_t = typename vector_type::type; +using ushort64_t = typename vector_type::type; // i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; // i8 using int8x2_t = typename vector_type::type; using int8x4_t = typename vector_type::type; using int8x8_t = typename vector_type::type; using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; // data type conversion template diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp index d0829c8c35..706b231792 100644 --- a/composable_kernel/include/utility/sequence_helper.hpp +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -5,11 +5,26 @@ namespace ck { +template +__host__ __device__ constexpr auto make_sequence(Number...) 
+{ + return Sequence{}; +} + +// F returns index_t template __host__ __device__ constexpr auto generate_sequence(F, Number) { return typename sequence_gen::type{}; } +// F returns Number<> +template +__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + } // namespace ck #endif diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp index ccb8b29a77..65c4a60dbb 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr auto C0 = C / Number{}; constexpr auto C1 = Number{}; -#if 1 +#if 0 // run-time variables constexpr auto in_n_hi_wi_c0_desc = make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); @@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); -#if 0 +#if 1 // cdata = 16, BlockSize = 64, 16x64x4 constexpr index_t BlockSize = 64; diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 2f490a323f..ab9de5b661 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -64,7 +64,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 0 +#elif 1 constexpr index_t N = 1; constexpr index_t C = 16; constexpr index_t HI = 1080; @@ -630,7 +630,7 @@ int main(int argc, char* argv[]) print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvDilations", to_multi_index(ConvDilations{})); -#if 1 +#if 0 using in_data_t = float; constexpr index_t in_vector_size = 1; using acc_data_t = float; @@ -724,23 +724,22 @@ int main(int argc, char* argv[]) LeftPads{}, RightPads{}, nrepeat); -#elif 1 +#elif 0 device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw - - (in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); + out_data_t>( + in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); #elif 0 device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk