From bb5bdff61c4eff7754a6f4bf71c9b6958177dfdd Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 30 May 2025 08:39:25 +0000 Subject: [PATCH] remove unnecessary files --- .../gemm_mx_fp8_bpreshuffle.cpp | 359 --- ...peline_xdlops_b_preshuffle_mx_selector.hpp | 94 - ...emm_pipeline_xdlops_b_preshuflle_v1_mx.hpp | 832 ------- ...emm_pipeline_xdlops_b_preshuflle_v3_mx.hpp | 932 ------- ...e_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp | 605 ----- ...e_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp | 2136 ----------------- 6 files changed, 4958 deletions(-) delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp delete mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_selector.hpp delete mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx.hpp delete mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_mx.hpp delete mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp delete mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp deleted file mode 100644 index 0bd6644725..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_bpreshuffle.cpp +++ /dev/null @@ -1,359 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" - -template -using S = ck::Sequence; - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using XDataType = ck::e8m0_bexp_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using A0DataType = F8; -using A1DataType = XDataType; -using B0DataType = F8; -using B1DataType = XDataType; -using AccDataType = F32; -using DsDataType = ck::Tuple<>; -using CDataType = BF16; -using CShuffleDataType = CDataType; - -using A0Layout = Row; -using B0Layout = Col; -using CLayout = Row; - -void preShuffleBuffer(const F8* src, F8* dst, int N, int K, int NXdl) -{ - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - - int K0 = K / (KLane * KPack); - // K -> K0 KLane KPack - // N -> N0 NLane - // N, K -> N0 K0 KLane NLane KPack - int tempk; - for(int n = 0; n < N; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / NLane; - int n1 = n % NLane; - - int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); - int k1 = tempk / KPack; - int k2 = tempk % KPack; - - int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + - k1 * KPack * NLane + n1 * KPack + k2; - - dst[outputIndex] = src[n * K + k]; - } - } -} - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size - -constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle< - A0Layout, B0Layout, CLayout, - A0DataType, A1DataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CElementOp, GemmSpec, - ScaleBlockSize, 256, - 128, 128, 128, - 16, 16, - 16, 16, - 8, 2, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 1, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, A0DataType, B0DataType>; -// clang-format on - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - bool flush_cache = true; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = K; - ck::index_t StrideB = K; - ck::index_t StrideC = N; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 8) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - flush_cache = std::stoi(argv[7]); - - StrideA = K; - StrideB = K; - StrideC = N; - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 6: M, N, K\n"); - printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); - exit(0); - } - - ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; - ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); - Tensor a_m_k_scale(f_host_tensor_descriptor( - M, (K + ScaleBlockSize - 1) / ScaleBlockSize, Scale_Stride_AM, A0Layout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b_k_n_scale(f_host_tensor_descriptor( - (K + ScaleBlockSize - 1) / ScaleBlockSize, N, Scale_Stride_BN, B0Layout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; - std::cout << "e_m_n: " << c_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 3: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_1{}); - break; - case 4: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - case 5: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - break; - case 6: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_1{}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - } - - DeviceMem a_device_buf(sizeof(A0DataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a_scale_device_buf(sizeof(A1DataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(B0DataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b_scale_device_buf(sizeof(B1DataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); - b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); - -#if 1 - printf("print a_m_k_scale:\n"); - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; ++k) - { - printf("%f ", ck::type_convert(a_m_k_scale(m, k))); - } - printf("\n"); - } -#endif - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - int NPerXdl = device_op.GetPreShuffleParameters(); - - preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl); - b_device_buf.ToDevice(b_preshuffled.mData.data()); - - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(a_scale_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(b_scale_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - Scale_Stride_AM, - StrideB, - Scale_Stride_BN, - StrideC, - 1, // KBatch - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + - sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; - - float ave_time = .0; - - if(flush_cache) - { - int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; - - ave_time = invoker.Run(argument, - StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); - } - else - { - ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << device_op.GetTypeString() << std::endl; - - if(do_verification) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - a_m_k_scale, - b_k_n, - b_k_n_scale, - c_m_n_host_result, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - return ck::utils::check_err( - c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) - ? 0 - : 1; - } - - return 0; -} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_selector.hpp deleted file mode 100644 index 511945d9d6..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_selector.hpp +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_mx.hpp" - -namespace ck { -template -constexpr auto BlockGemmMXBPreshufflePipeline_Selector() -{ - - // Hardware MX GEMM pipeline - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx{}; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3_mx{}; - } - else - { - std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; - } -} - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx.hpp deleted file mode 100644 index 133443feaf..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v1_mx.hpp +++ /dev/null @@ -1,832 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" - -namespace ck { - -// Naive pipeline with lowest resource request per WGP -// GlobalPrefetchStages: 2 -// LocalPreFillStages: 1 -// LocalPreFetchStages: 1 -// LocalSharedMemoryBuffer: 1 - -template -struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx -{ -}; - -template -struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1_mx - : BlockwiseGemmXdlops_mx_pipeline_base - -{ - - using Base = BlockwiseGemmXdlops_mx_pipeline_base; - using Base::I0; - using Base::I1; - using Base::KRepeat; - using Base::MWaves; - using Base::NWaves; - using Base::WaveSize; - using Base::xdlops_gemm; - - using Base::CalculateCThreadOriginDataIndex; - using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetCThreadBuffer; - using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetWaveIdx; - using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - - using Base::a_block_desc_m0_m1_m2_k; - using Base::b_block_desc_n0_n1_n2_k; - - using Base::AMmaKStride; - using Base::BMmaKStride; - using Base::KThreadChunk; - - using Base::APackedSize; - using Base::BPackedSize; - using Base::ComputePackedSize; - - using AccType = typename Base::AccType; - using Tuple4 = typename Base::Tuple4; - using ComputeTypeA = typename Base::ComputeTypeA; - using ComputeTypeB = typename Base::ComputeTypeB; - - static constexpr index_t PrefetchStages = 2; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 2; - - template - __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) - { - constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); - constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); - constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; - constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; - - return transform_tensor_descriptor( - TileDesc_M0_M1_M2_K{}, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); - } - - static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = - MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); - - static constexpr auto ScalesPerKBlockSize = - KPerBlock / ScaleBlockSize; // How many mx-vectors per K block - - //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; - - //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRunPerThread = - ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; - - __host__ static constexpr bool BlockHasHotloop(index_t num_loop) - { - return num_loop > PrefetchStages; - } - - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) - { - return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; - } - - template - __device__ void Run( - // ABlockCopy - const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - // BBlockCopy - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - // CThread - CThreadBuffer& c_thread_buf, - // A and B scales - const AScaleGridDesc& a_scale_grid_desc, - AScaleThreadTransfer& a_scale_thread_copy, - const AScaleGridBuffer& a_scale_grid_buf, - const BScaleGridDesc& b_scale_grid_desc, - BScaleThreadTransfer& b_scale_thread_copy, - const BScaleGridBuffer& b_scale_grid_buf, - index_t num_loop) const - { - ignore = b_block_desc; - ignore = b_block_buf; - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - StaticallyIndexedArray{}> b_thread_bufs; - constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); - - auto a_scale_thread_buf = make_static_buffer( - a_scale_thread_desc.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); - - StaticallyIndexedArray{}> a_scale_thread_bufs; - StaticallyIndexedArray{}> b_scale_thread_bufs; - - // Global prefetch A1 B1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I0)); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Prefetch a_scales to buf 0 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(I0)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); - - // Prefetch b_scales to buf 0 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(I0)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - - // restore col id and advance to the next set of scales - // NWaves * NPerXDL * NRepeat == NPerBlock - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(-NPerBlock, ScalesPerKBlockSize)); - - __builtin_amdgcn_sched_barrier(0); - - // Local prefill A1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - - // Global prefetch A2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - // Prefetch a_scales to buf 1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(I1)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); - - // Prefetch b_scales to buf 1 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(I1)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(-NPerBlock, ScalesPerKBlockSize)); - - // Local prefetch A1 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - // loop over k with the step KPerBlock - index_t i = 0; - do - { - auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - static_assert( - 0 < ScalesPerXdlopsRunPerThread, - "Must have at least one scale per Xdlops per Thread."); - - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[mfma_reg_buf] - [Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - - // a thread copy - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * - xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - // Prefetch a_scales - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(mfma_reg_buf)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); - - // Prefetch b_scales - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(mfma_reg_buf)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); - }; - - LoopFunc(I0, I1); - LoopFunc(I1, I0); - - i += 2; - } while(i < (num_loop - 2)); - } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - - // a thread copy - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I1][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - } - else if constexpr(TailNum == TailNumber::Odd) - { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - } - } - - // TODO: make this field protected when a_scale_thread_copy_ is moved - // here - static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - - // TODO: make this field protected when b_scale_thread_copy_ is moved - // here - static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from b_scale_grid to b_scale_thread_buf - static constexpr auto b_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - - protected: - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); - using Base::a_thread_copy_; - using Base::a_thread_desc_; - using Base::b_thread_copy_; - // using Base::b_thread_desc_; - using Base::c_thread_desc_; - - static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_mx.hpp deleted file mode 100644 index 2f4dca4aeb..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuflle_v3_mx.hpp +++ /dev/null @@ -1,932 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" - -namespace ck { - -// Naive pipeline with lowest resource request per WGP -// GlobalPrefetchStages: 2 -// LocalPreFillStages: 1 -// LocalPreFetchStages: 1 -// LocalSharedMemoryBuffer: 1 - -template -struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3_mx -{ -}; - -template -struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3_mx - : BlockwiseGemmXdlops_mx_pipeline_base - -{ - - using Base = BlockwiseGemmXdlops_mx_pipeline_base; - using Base::I0; - using Base::I1; - using Base::I2; - using Base::KRepeat; - using Base::MWaves; - using Base::NWaves; - using Base::WaveSize; - using Base::xdlops_gemm; - using typename Base::HotLoopInstList; - - using Base::CalculateCThreadOriginDataIndex; - using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetCThreadBuffer; - using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; - using Base::GetWaveIdx; - using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; - using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; - - using Base::a_block_desc_m0_m1_m2_k; - using Base::b_block_desc_n0_n1_n2_k; - - using Base::AMmaKStride; - using Base::BMmaKStride; - using Base::KThreadChunk; - - using Base::APackedSize; - using Base::BPackedSize; - using Base::ComputePackedSize; - - using AccType = typename Base::AccType; - using Tuple4 = typename Base::Tuple4; - using ComputeTypeA = typename Base::ComputeTypeA; - using ComputeTypeB = typename Base::ComputeTypeB; - - static constexpr index_t PrefetchStages = 2; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 2; - static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; - - template - __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) - { - constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); - constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); - constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; - constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; - - return transform_tensor_descriptor( - TileDesc_M0_M1_M2_K{}, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); - } - - static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = - MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); - - static constexpr auto ScalesPerKBlockSize = - KPerBlock / ScaleBlockSize; // How many mx-vectors per K block - - //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; - - //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() - static constexpr auto ScalesPerXdlopsRunPerThread = - ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; - - __host__ static constexpr bool BlockHasHotloop(index_t num_loop) - { - return num_loop > PrefetchStages; - } - - __device__ static constexpr auto HotLoopScheduler() - { - // A/B split schedule - // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = - HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 - ? HotLoopInstList::A_LDS_Read_Inst_Num - : HotLoopInstList::A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = - HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 - ? HotLoopInstList::B_LDS_Read_Inst_Num - : HotLoopInstList::B_LDS_Read_Inst_Num / 2; - - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; - - constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto ds_read_a_issue_cycle = - HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; - constexpr auto ds_read_b_issue_cycle = - HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; - constexpr auto ds_read_a_mfma_rate = - (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); - constexpr auto ds_read_b_mfma_rate = - (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); - - constexpr auto num_dsread_a_mfma = - (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; - constexpr auto num_dsread_b_mfma = - (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; - - // stage 1 - // Separate this part? - // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); - constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); - constexpr auto num_mfma_per_issue = - num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); - constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; - constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; - - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { - ignore = i; - static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA - }); - static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { - ignore = i; - static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { - ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier( - 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA - }); - - // stage 2 - static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { - if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= - ds_read_a_mfma_rate) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier(0x100, - num_ds_read_inst_a - (num_dsread_a_mfma - 1) * - ds_read_a_mfma_rate, - 0); // DS read - } - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - - static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { - if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= - ds_read_b_mfma_rate) - { - __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier(0x100, - num_ds_read_inst_b - (num_dsread_b_mfma - 1) * - ds_read_b_mfma_rate, - 0); // DS read - } - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - } - - __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) - { - return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; - } - - template - __device__ void Run( - // ABlockCopy - const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - // BBlockCopy - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - // CThread - CThreadBuffer& c_thread_buf, - // A and B scales - const AScaleGridDesc& a_scale_grid_desc, - AScaleThreadTransfer& a_scale_thread_copy, - const AScaleGridBuffer& a_scale_grid_buf, - const BScaleGridDesc& b_scale_grid_desc, - BScaleThreadTransfer& b_scale_thread_copy, - const BScaleGridBuffer& b_scale_grid_buf, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - StaticallyIndexedArray{}> b_thread_bufs; - constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); - - auto a_scale_thread_buf = make_static_buffer( - a_scale_thread_desc.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); - - StaticallyIndexedArray{}> a_scale_thread_bufs; - StaticallyIndexedArray{}> b_scale_thread_bufs; - - // Global prefetch B1 - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I0)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Global prefetch A1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - // Prefetch a_scales 1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(I0)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, - make_multi_index(-MPerBlock, ScalesPerKBlockSize)); - - // Prefetch b_scales 1 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(I0)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - // restore col id and advance to the next set of scales - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(-NPerBlock, ScalesPerKBlockSize)); - - // Local prefill A1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); // vmem->vgpr-> lds0 - - // Global prefetch A2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // Local prefetch A1 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - // main body - if constexpr(HasMainLoop) - { - // loop over k with the step KPerBlock - index_t i = 0; - do - { - auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf, auto a_buf) { - // Prefetch a_scales 2 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(local_read_buf)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - // restore row id and advance to the next set of scales - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); - - // Prefetch b_scales 2 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(local_read_buf)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - // restore col id and advance to the next set of scales - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); - - // Local prefill A2 - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - - // Global prefetch A1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - - // Global prefetch B2 - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // A1 * B1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type - a_scale_thread_vec; - vector_type - b_scale_thread_vec; - - // Pack scale_thread_buf into scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[mfma_reg_buf] - [Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat - - // Local prefetch A2 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}( - [&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * - xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - }; // LoopFunc - - LoopFunc(I0, I1, I0); - LoopFunc(I1, I0, I1); - - i += 2; - } while(i < (num_loop - 2)); - } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - // Prefetch a_scales 2 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); - auto a_scale_thread_buf_copy = - make_static_buffer( - a_scale_thread_desc_copy.GetElementSpaceSize()); - a_scale_thread_copy.Run(a_scale_grid_desc, - a_scale_grid_buf, - a_scale_thread_desc_copy, - make_tuple(I0, I0), - a_scale_thread_buf_copy); - - a_scale_thread_bufs(I1)(Number{}) = - a_scale_thread_buf_copy[Number<0>{}]; - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); - - // Prefetch b_scales 2 - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - constexpr auto b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); - auto b_scale_thread_buf_copy = - make_static_buffer( - b_scale_thread_desc_copy.GetElementSpaceSize()); - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc_copy, - make_tuple(I0, I0), - b_scale_thread_buf_copy); - - b_scale_thread_bufs(I1)(Number{}) = - b_scale_thread_buf_copy[Number<0>{}]; - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, - make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); - }); - }); - b_scale_thread_copy.MoveSrcSliceWindow( - b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); - }); - - // Local prefill A2 - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - - // Global prefetch B2 - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - - // A1 * B1 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat - - // Local prefetch A2 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto k_step = - k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { - constexpr auto a_k_step_chunk = - k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(m0, I0, k, Number{}), - a_thread_buf); - }); - }); - }); - - // A2 * B2 - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I1][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I1][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat - } - else if constexpr(TailNum == TailNumber::Odd) - { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; - }); - - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - - vector_type a_scale_thread_vec; - vector_type b_scale_thread_vec; - - // Pack b_scale_thread_buf into b_scale_thread_vec - static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { - a_scale_thread_vec.template AsType()(s) = - a_scale_thread_bufs[I0][Number{}]; - b_scale_thread_vec.template AsType()(s) = - b_scale_thread_bufs[I0][Number{}]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // MFMA accumulation - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - a_scale_thread_vec.template AsType(), - b_thread_vec.template AsType(), - b_scale_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); // KRepeat - }); // NRepeat - }); // MRepeat - } - } - - // TODO: make this field protected when a_scale_thread_copy_ is moved - // here - static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - - // TODO: make this field protected when b_scale_thread_copy_ is moved - // here - static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); - - // Is used to copy data from b_scale_grid to b_scale_thread_buf - static constexpr auto b_scale_thread_desc_copy = - make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - - protected: - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); - using Base::a_thread_copy_; - using Base::a_thread_desc_; - using Base::b_thread_copy_; - // using Base::b_thread_desc_; - using Base::c_thread_desc_; - - static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp deleted file mode 100644 index dda0e4906a..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp +++ /dev/null @@ -1,605 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" - -#include "ck/host_utility/flush_cache.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_b_preshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// clang-format off -/** - * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types - * - * This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for - * microscale-compliant data types. - * - * Assumptions: - * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification - * - Each scale applies to ScaleBlockSize elements in K direction - * - A scale matrix is a row-major - * - B scale matrix is a column-major - * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the - * exponent will be interpreted as conventional biased Float32 exponent (E8M0) - * - * Tunable parameters. - * The CK instance includes a series of tunable template parameters to control the parallel - * granularity of the workload to achieve load balancing on different hardware platforms. These - * parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc. - * - Block Size determines the number of threads in the thread block. - * - M/N/K Per Block determines the size of tile that each thread block is responsible for - * calculating. - * - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) - * instructions operating on a per-wavefront basis. - * - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To - * achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B - * loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will - * result in compilation errors. - * - * Conditions for achieving computational load balancing on different hardware platforms can vary. - * - * Serialized version of the algorithm: - * \code - * // E = A * B + C - * // Loop over E[MPerBlock,NPerBlock] tiles - * for(int mb = 0; mb < M; mb += MPerBlock){ - * for(int nb = 0; nb < N; nb += NPerBlock){ - * // initialize E[MPerBlock,NPerBlock] tile - * for(int mt = mb; mt < mb + MPerBlock; mt++){ - * for(int nt = nb; nt < nb + NPerBlock; nt++){ - * E[mt,nt] = C[mt,nt]; - * } - * } - * - * // multiply-accumulate per tile - * for(int kb = 0; kb < K; kb += KPerBlock){ - * for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){ - * for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){ - * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ - * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ - * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ - * // MFMA accumulation - * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){ - * // MFMA instruction - * for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){ - * for(int m = mw; m < mw + MPerXDL; m++){ - * for(int n = nw; n < nw + NPerXDL; n++){ - * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ - * E[m,n] += A[m,k] * B[k,n]; - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * \endcode - * - */ -// clang-format on -template -struct DeviceGemmMX_Xdl_CShuffleV3_BPreShuffle - : public DeviceGemmMX_BPreshuffle -{ - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3_b_preshuffle< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>; - - using Argument = typename GridwiseGemm::Argument; - - int GetPreShuffleParameters() override { return NPerXDL; } - - // Invoker - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - - auto size_a_buffer = - a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); - auto size_b_buffer = - b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - - ck::utility::RotatingMemWrapper rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - // TODO: Check if this is the right algorithm for minimum_occupancy - constexpr index_t minimum_occupancy = - BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave - ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && - MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) - ? 2 - : 1 - : 2; - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), - "Only microscaling formats are supported for ADataType and BDataType"); - - static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); - - static_assert(is_same_v && is_same_v, - "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType"); - - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if constexpr(!IsValidCompilationParameter()) - { - return false; - } - - if(!ck::is_xdl_supported()) - { - return false; - } - - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) - { - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const AScaleDataType* p_a_scale, - const BDataType* p_b, - const BScaleDataType* p_b_scale, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideScaleA, - index_t StrideB, - index_t StrideScaleB, - index_t StrideC, - index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_a_scale, - p_b, - p_b_scale, - p_c, - M, - N, - K, - StrideA, - StrideScaleA, - StrideB, - StrideScaleB, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_a_scale, - const void* p_b, - const void* p_b_scale, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideScaleA, - ck::index_t StrideB, - ck::index_t StrideScaleB, - ck::index_t StrideC, - ck::index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_a_scale), - static_cast(p_b), - static_cast(p_b_scale), - static_cast(p_c), - M, - N, - K, - StrideA, - StrideScaleA, - StrideB, - StrideScaleB, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceGemmMX_Xdl_CShuffleV3" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock<<"x"< -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); - -#else - ignore = karg; -#endif // end of if (defined(__gfx9__)) -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - GridwiseGemm::template Run_2Lds( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared_0, - p_shared_1, - karg); - -#else - ignore = karg; -#endif // end of if (defined(__gfx9__)) -} - -template -struct GridwiseGemmMX_xdl_cshuffle_v3_b_preshuffle -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = false; - static constexpr auto is_scale_mfma = true; - - //> KPack is at least the k_per_blk of selected mfma - // - // Should be a multiple of k_per_blk. - // TODO: Move this to blockwise pipeline base - using mfma_selector = MfmaSelector; - static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk); - - static constexpr index_t KGroup = 1; // mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; - - using ThisThreadBlock = ThisThreadBlock; - - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t> || - is_same_v, f4x2_pk_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t> || - is_same_v, f4x2_pk_t>) - return 2; - else - return 1; - }(); - - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) - { - return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); - } - - __host__ static auto CalculateMPadded(index_t M) - { - return math::integer_least_multiple(M, MPerBlock); - } - - __host__ static auto CalculateNPadded(index_t N) - { - return math::integer_least_multiple(N, NPerBlock); - } - - __host__ __device__ static auto CalculateBN0Shuffled(index_t N) - { - return math::integer_divide_ceil(N, NLane); - } - __host__ __device__ static auto CalculateBK0Shuffled(index_t K) - { - return math::integer_divide_ceil(K, KLane * KPack / KGroup); - } - - __host__ static auto CalculateKPadded(index_t K) - { - return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; - } - - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); - } - - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); - } - - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * KPerBlock; - } - - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = K_Batch * KReadVec; - return (K + K_t - 1) / K_t * KReadVec; - } - - __host__ static auto CalculateMBlock(index_t M) - { - return math::integer_divide_ceil(M, MPerBlock); - } - - __host__ static auto CalculateNBlock(index_t N) - { - return math::integer_divide_ceil(N, NPerBlock); - } - - template - __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) - { - constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); - constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - - return transform_tensor_descriptor( - TileDesc_K0_MN_K1{}, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); - } - - __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( - index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_right_pad_transform(M, MPad - M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - // not pad M or K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) - { - constexpr index_t NkSwizzleNumber = Number{}; - return make_naive_tensor_descriptor( - make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), - make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); - } - - __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( - index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - static_assert(!(is_same_v, pk_i4_t> && - GemmSpec != GemmSpecialization::Default), - "pk_i4_t does not support padding"); - static_assert(!(is_same_v, f4x2_pk_t> && - GemmSpec != GemmSpecialization::Default), - "f4x2_pk_t does not support padding"); - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - if constexpr(!PermuteB) - { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // Weight Tile Permute - constexpr index_t BK01 = KPerBlock / BK1Value; - // const index_t BK00 = BK0 / BK01; - const index_t BK0_ = StrideB / BK1Value; - const index_t BK00 = BK0_ / BK01; - - const auto b_grid_desc_bk00_n_bk01_bk1_permute = - make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); - - const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( - b_grid_desc_bk00_n_bk01_bk1_permute, - make_tuple(make_merge_transform(make_tuple(BK00, BK01)), - make_pass_through_transform(make_tuple(N)), - make_pass_through_transform(BK1Value)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_grid_desc_bk0_n_bk1_permute; - } - } - } - - template - __host__ __device__ static constexpr auto - MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) - { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); - } - - template - __host__ __device__ static constexpr auto - MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) - { - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); - } - - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); -#if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } -#endif - } - - struct Problem - { - __host__ Problem(index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideScaleA_, - index_t StrideB_, - index_t StrideScaleB_, - index_t StrideC_, - index_t KBatch_) - : M{M_}, - N{N_}, - K{K_}, - StrideA{StrideA_}, - StrideScaleA{StrideScaleA_}, - StrideB{StrideB_}, - StrideScaleB{StrideScaleB_}, - StrideC{StrideC_}, - KBatch{KBatch_}, - MPadded{CalculateMPadded(M_)}, - NPadded{CalculateNPadded(N_)}, - KRead{CalculateKRead(K_, KBatch_)}, - KPadded{CalculateKPadded(K_, KBatch_)}, - AK0{CalculateAK0Padded(K_, KBatch_)}, - BK0{CalculateBK0Padded(K_, KBatch_)}, - MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} - { - } - - __host__ void Print() const - { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; - } - - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideScaleA; - index_t StrideB; - index_t StrideScaleB; - index_t StrideC; - index_t KBatch; - index_t MPadded; - index_t NPadded; - index_t KRead; - index_t KPadded; - index_t AK0; - index_t BK0; - index_t MBlock; - index_t NBlock; - // For Preshuffle Only - index_t BN0Shuffled; - index_t BK0Shuffled; - }; - - // Argument - struct Argument : public tensor_operation::device::BaseArgument, public Problem - { - __host__ Argument(const ADataType* p_a_grid_, - const AScaleDataType* p_a_scale_grid_, - const BDataType* p_b_grid_, - const BScaleDataType* p_b_scale_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideScaleA_, - index_t StrideB_, - index_t StrideScaleB_, - index_t StrideC_, - index_t k_batch_, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_, - bool is_reduce_ = false) - : Problem{M_, - N_, - K_, - StrideA_, - StrideScaleA_, - StrideB_, - StrideScaleB_, - StrideC_, - k_batch_}, - p_a_grid{p_a_grid_}, - p_a_scale_grid{p_a_scale_grid_}, - p_b_grid{p_b_grid_}, - p_b_scale_grid{p_b_scale_grid_}, - p_c_grid{p_c_grid_}, - a_element_op{a_element_op_}, - b_element_op{b_element_op_}, - c_element_op{c_element_op_}, - is_reduce(is_reduce_) - { - } - - __host__ __device__ inline bool IsReduceAdd() const - { - return (Problem::KBatch > 1) && is_reduce; - } - - __host__ __device__ inline bool IsAtomicAdd() const - { - return (Problem::KBatch > 1) && (!is_reduce); - } - - const ADataType* p_a_grid; - const AScaleDataType* p_a_scale_grid; - const BDataType* p_b_grid; - const BScaleDataType* p_b_scale_grid; - CDataType* p_c_grid; - - const AElementwiseOperation a_element_op; - const BElementwiseOperation b_element_op; - const CElementwiseOperation c_element_op; - bool is_reduce; - }; - - struct SplitKBatchOffset - { - - __device__ SplitKBatchOffset(Argument& karg, index_t k_id) - { - if constexpr(is_same_v) - { - a_k_split_offset = k_id * karg.KRead / APackedSize; - } - else if constexpr(is_same_v) - { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; - } - - if constexpr(is_same_v) - { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - if constexpr(!PermuteB) - { - b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; - } - else - { - const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; - } - } - - // Calculate A scale offset - if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; - } - else if constexpr(is_same_v) - { - a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA; - } - - // Calculate B scale offset - if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB; - } - else if constexpr(is_same_v) - { - b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; - } - - if(k_id < (karg.KBatch - 1)) - { - karg.K = karg.KRead; - } - else - { - karg.K = karg.K - karg.KRead * (karg.KBatch - 1); - } - - if(karg.IsReduceAdd()) - { - c_reduce_offset = k_id * karg.M * karg.N; - } - else - { - c_reduce_offset = 0; - } - } - - index_t a_k_split_offset; - index_t b_k_split_offset; - index_t a_scale_k_split_offset; // New member for scale matrix offset - index_t b_scale_k_split_offset; // New member for scale matrix offset - index_t c_reduce_offset; - }; - - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - return make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(ADataType) / APackedSize, - c_block_size * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ static constexpr bool CheckValidity(const Argument& karg) - { - static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && - (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, - "Invalid tuning param!"); - - static_assert(KPerBlock % ScaleBlockSize == 0, - "KPerBlock should be multiple of ScaleBlockSize"); - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - !(is_same::value)) - { - if(!(karg.M % MPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - (is_same::value)) - { - if(!(karg.N % NPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) - { - auto K_t = karg.KBatch * KPerBlock; - if(!(karg.K % K_t == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = karg.KBatch * KReadVec; - auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; - if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) - { - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.K % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.M % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.K % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - else - { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) - { - if(!karg.IsReduceAdd()) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; - } - if(karg.KBatch > 1) - { - return false; - } - } - } - - // check gridwise gemm pipeline -#if 0 - const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); - - if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) - { - return false; - } -#endif - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); - } - - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); - } - - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) - { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), - make_unmerge_transform(make_tuple(NBlock, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - return c_grid_desc_mblock_mperblock_nblock_nperblock; - } - - // return block_id to C matrix tile idx (m0, n0) mapping - // if arch = gfx942 - using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - - template - __device__ static void Run(const ADataType* p_a_grid, - const AScaleDataType* p_a_scale_grid, - const BDataType* p_b_grid, - const BScaleDataType* p_b_scale_grid, - CDataType* p_c_grid, - void* p_shared, - const Problem& problem, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, - const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, - const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - // A Scale buffer - const auto a_scale_grid_buf = make_dynamic_buffer( - p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - - // B Scale buffer - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - // const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - // Thread-wise copy - // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf = make_static_buffer( - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); - - // LDS allocation for A and B: be careful of alignment - - // Cast after lds - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - // Initial thread mapping for: - // BlockSize = 256 - // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 - // For each [m0, n0] tile, there are 4 waves: - // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] - // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] - // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] - // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - - // BlockSize = 128 - // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 - // For each [m0, n0] tile, there are 2 waves: - // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] - // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - - // TODO: Document initial thread mapping for more combinations of parameters - - const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - - auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / - mfma.selected_mfma.num_threads_per_blk; - - auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; - - auto a_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>( - a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); - - auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; - - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, - true>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, - a_scale_thread_copy, - a_scale_grid_buf, - b_scale_grid_desc_bn_ak, - b_scale_thread_copy, - b_scale_grid_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } - - template - __device__ static void Run(const ADataType* p_a_grid, - const AScaleDataType* p_a_scale_grid, - const BDataType* p_b_grid, - const BScaleDataType* p_b_scale_grid, - CDataType* p_c_grid, - void* p_shared, - const Problem& problem) - { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - - // A Scale grid - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( - make_tuple(problem.M, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleA, 1)); - - // B Scale grid transposed - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleB, 1)); - - Run(p_a_grid, - p_a_scale_grid, - p_b_grid, - p_b_scale_grid, - p_c_grid, - p_shared, - problem, - a_grid_desc_ak0_m_ak1, - a_scale_grid_desc_am_ak, - b_grid_desc_bpreshuffled, - b_scale_grid_desc_bn_ak, - c_grid_desc_mblock_mperblock_nblock_nperblock); - } - - template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const AScaleDataType* p_a_scale_grid, - const BDataType* p_b_grid, - const BScaleDataType* p_b_scale_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, - const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, - const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - // A Scale buffer - const auto a_scale_grid_buf = make_dynamic_buffer( - p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - - // B Scale buffer - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - // const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - // Thread-wise copy - // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack - auto b_block_buf = make_static_buffer( - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< - BDataType, - BDataType, - decltype(b_grid_desc_bpreshuffled), - decltype(b_block_desc_bk0_n_bk1), - Sequence{}, I1, Number{}, Number{}>, - Sequence<1, 2, 0, 3>, - 3, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_grid_desc_bpreshuffled, - make_multi_index(n_block_data_idx_on_grid, - get_warp_local_1d_id() % NWave, - 0, - KPack / KGroup * (get_thread_local_1d_id() % warpSize))); - - // LDS allocation for A and B: be careful of alignment - constexpr auto max_lds_align = AK1Number; - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize, max_lds_align); - - auto a_block_buf_ping = make_dynamic_buffer( - static_cast(p_shared_0), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); - - auto a_block_buf_pong = make_dynamic_buffer( - static_cast(p_shared_1), - a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); - - auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - - auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / - mfma.selected_mfma.num_threads_per_blk; - - auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; - - auto a_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>( - a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); - - auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; - - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - 1, // SrcScalarPerVector - 1, - true>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - a_scale_grid_desc_am_ak, - a_scale_thread_copy, - a_scale_grid_buf, - b_scale_grid_desc_bn_ak, - b_scale_thread_copy, - b_scale_grid_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } - - template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const AScaleDataType* p_a_scale_grid, - const BDataType* p_b_grid, - const BScaleDataType* p_b_scale_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, - const Problem& problem) - { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); - - // A Scale grid - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( - make_tuple(problem.M, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleA, 1)); - - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), - make_tuple(problem.StrideScaleB, 1)); - - Run_2Lds(p_a_grid, - p_a_scale_grid, - p_b_grid, - p_b_scale_grid, - p_c_grid, - p_shared_0, - p_shared_1, - problem, - a_grid_desc_ak0_m_ak1, - a_scale_grid_desc_am_ak, - b_grid_desc_bpreshuffled, - b_scale_grid_desc_bn_ak, - c_grid_desc_mblock_mperblock_nblock_nperblock); - } -}; - -} // namespace ck