diff --git a/CMakeLists.txt b/CMakeLists.txt index be4efd3dfd..d5d4cc64a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -516,10 +516,6 @@ include_directories(BEFORE ) SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) -endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 93fd306e98..d5bcd6f978 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,6 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index cb4f60764e..a6580c85ce 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -63,6 +63,40 @@ struct MultiplyMultiply } }; +void reshapeBuffer(char* buffer, int N, int K, char* output) { + const int KRepeat = 2; + const int NRepeat = 3; + const int KLane = 4; + const int NLane = 5; + const int KPack = 6; + int N0 = N / (NRepeat * NLane); + int K0 = K / (KRepeat * KLane * KPack); + + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K; ++k) { + int n0 = n / (NRepeat * NLane); + int k0 = k / (KRepeat * KLane * KPack); + int nRel = n % (NRepeat * NLane); + int kRel = k % (KRepeat * KLane * KPack); + + int nIndex = nRel / NLane; + int kIndex = kRel / (KLane * KPack); + int nLaneIndex = nRel % NLane; + int kLaneIndex = (kRel % (KLane * KPack)) / KPack; + int kPackIndex = kRel % KPack; + + int outputIndex = (n0 * K0 + k0) * KRepeat * NRepeat * KLane * NLane * KPack + + nIndex * KRepeat * KLane * KPack + + kIndex * KLane * KPack + + nLaneIndex * KPack + + kLaneIndex * KPack + + kPackIndex; + + output[outputIndex] = buffer[n * K + k]; + } + } +} + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; @@ -77,10 +111,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR - ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; ///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + // kernel 2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 171a232c0f..a1bd1fd1fe 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -305,11 +305,11 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + // make_tuple(n0, I0, I0, Number{}), + // b_block_buf, + // b_thread_desc_, + // make_tuple(n0, I0, k0, I0), + // b_thread_buf); + // }); }); __builtin_amdgcn_sched_barrier(0); @@ -351,7 +351,7 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + // make_tuple(n0, I0, I0, Number{}), + // b_block_buf, + // b_thread_desc_, + // make_tuple(n0, I0, k0, I0), + // b_thread_buf); + // }); }); HotLoopScheduler(); @@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3{}; static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + static constexpr index_t NLane = 128; + static constexpr index_t KLane = 2; + static constexpr index_t KRepeat = 4; + static_assert(NLane * KLane == BlockSize); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -200,6 +205,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 return math::integer_least_multiple(N, NPerBlock); } + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_least_multiple(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K, index_t KBatch) + { + return math::integer_least_multiple(K, KLane * KPack * KBatch); + } + __host__ __device__ static auto CalculateKPadded(index_t K) { return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; @@ -337,6 +351,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } } + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NKSWIZZLE_V = BlockSize * KPack; + constexpr index_t NKSWIZZLE_N = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0, K0, NKSWIZZLE_N), + make_tuple(K0 * NKSWIZZLE_V, NKSWIZZLE_N, I1) + ); + } + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) { @@ -549,7 +573,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)} + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_, KBatch_)} { } @@ -588,6 +614,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 index_t BK0; index_t MBlock; index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; }; // Argument @@ -989,7 +1018,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { // LDS allocation for A and B: be careful of alignment constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -997,8 +1026,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + // constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + // b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = @@ -1007,8 +1036,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + - b_block_space_size_aligned * sizeof(LDSTypeB)), + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), c_block_size * sizeof(CShuffleDataType)); } @@ -1264,8 +1292,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled( + problem.BN0Shuffled, problem.BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); @@ -1276,7 +1304,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1299,7 +1327,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock) / NLane; // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1341,20 +1369,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); + // using BThreadClusterLengths = Sequence<1, 1, BlockSize>; + // using BBlockTransferClusterArrangeOrder = Sequence<0, 1, 2>; // B matrix blockwise copy auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, + Sequence<1, KRepeat, KPack * BlockSize>, + Sequence<1, 1, BlockSize>, //BThreadClusterLengths, + Sequence<0, 1, 2>, //BBlockTransferClusterArrangeOrder, BDataType, LDSTypeB, - decltype(b_grid_desc_bk0_n_bk1), + decltype(b_grid_desc_bpreshuffled), decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>,//BBlockTransferSrcAccessOrder, Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2, @@ -1365,8 +1395,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 BThreadTransferSrcResetCoordinateAfterRun, true, BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, 0, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1386,7 +1416,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0); // Blockwise GEMM pipeline static_assert(std::is_default_constructible_v); @@ -1403,7 +1433,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, + b_grid_desc_bpreshuffled, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, @@ -1673,472 +1703,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } } - template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; - Run_2Lds( - p_a_grid, - p_b_grid, - p_ds_grid, - p_c_grid, - p_shared_0, - p_shared_1, - problem, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); - } - - template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, - DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - const Block2CTileMap& block_2_ctile_map) - { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - LDSTypeA, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - LDSTypeB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - auto a_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0) + - a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - auto a_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1) + - a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); - auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } - } }; } // namespace ck