diff --git a/CMakeLists.txt b/CMakeLists.txt index 20365a6130..1fe1bc91d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,17 +196,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") message("Enabling FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") + add_definitions(-DCK_USE_AMD_MFMA_GFX950) +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() @@ -214,6 +217,10 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_FNUZ_FP8) set(CK_USE_FNUZ_FP8 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") +endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) diff --git a/Jenkinsfile b/Jenkinsfile index 835b7e724f..51d492047e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -722,9 +722,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM pipeline { agent none - triggers { - parameterizedCron(CRON_SETTINGS) - } options { parallelsAlwaysFailFast() } diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index ce5834d1e2..9e2012bf8a 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -56,7 +56,7 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() - if (GPU_TARGETS MATCHES "gfx12") + if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") endif() diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 77f15a213c..97ac21eba5 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -61,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 07d51855d6..414683ffdf 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 
2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; -// // clang-format on -// clang-format off using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index be47665a26..aa9367cdcf 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc0..018b57f82c 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index e3370b880b..ce42a20be7 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +#if defined(CK_USE_AMD_MFMA_GFX950) +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // 
ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; +#else // defined(CK_USE_AMD_MFMA_GFX950) template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -80,6 +130,7 @@ using DeviceConvFwdInstance = 1, S<1, 16, 1, 16>, 4>; +#endif // defined(CK_USE_AMD_MFMA_GFX950) template using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +using ScaleDataType = ck::e8m0_bexp_t; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct ExecutionConfig final +{ + int do_verification = 1; // (0=no, 1=CPU) + int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + bool time_kernel = false; // (0=no, 1=yes) + int verbosity = 0; // (0=no info, 1=verbose info) +}; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; +}; + +bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + } + else if(argc == 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.StrideA = std::stoi(argv[8]); + problem_size.StrideB = std::stoi(argv[9]); + problem_size.StrideC = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl + << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} + 
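The new gemm_mx_common.hpp drives a block-scaled ("microscaling") GEMM: every MXVectorSize consecutive elements of A and B along the K dimension share a single e8m0 scale (the file sets Scale_Block_K = MXVectorSize), so the kernel effectively computes C = (A ⊙ scale_A) · (B ⊙ scale_B). As a host-side reference for that convention, here is a minimal sketch, assuming OCP-MX e8m0 semantics (an exponent-only byte decoding to 2^(e − 127), with 0xFF reserved for NaN); decode_e8m0 and dequantize_row are illustrative helpers, not CK API.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: decode an OCP-MX e8m0 scale byte (exponent-only, bias 127,
// 0xFF encodes NaN) to float. CK wraps this format as ck::e8m0_bexp_t.
inline float decode_e8m0(std::uint8_t e)
{
    return e == 0xFFu ? std::nanf("") : std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Dequantize one M-row of A: the block of mx_vector_size consecutive K-elements
// starting at k = i * mx_vector_size shares the scale stored at scales[i].
std::vector<float> dequantize_row(const std::vector<float>& row, // elements already converted to float
                                  const std::vector<std::uint8_t>& scales,
                                  std::size_t mx_vector_size)
{
    std::vector<float> out(row.size());
    for(std::size_t k = 0; k < row.size(); ++k)
        out[k] = row[k] * decode_e8m0(scales[k / mx_vector_size]);
    return out;
}
```

This mirrors the verification path in run_mx_gemm below, which expands a_m_k(m, k) * a_m_k_scale(m, k / Scale_Block_K) (and the analogous expression for B) before running the reference GEMM; with the default initialization A = 1, scale_A = 0.5, B = 1, scale_B = 2.0, every scaled product is 1.0 and each output element accumulates to exactly K.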
+template +bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using ELayout = CLayout; + using DsLayout = ck::Tuple<>; + using DsDataType = ck::Tuple<>; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = CElementWiseOp; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +#if 1 + // XXX: These parameters should not exist in MX-native GEMM kernel + static constexpr ck::index_t Scale_Block_M = 128; + static constexpr ck::index_t Scale_Block_N = 128; +#endif + static constexpr ck::index_t Scale_Block_K = MXVectorSize; + + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA + // instructions. + // + // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized + // scaled type convert functions. + // + // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to + // ScaleBlockK (aka MXVectorSize). + // Additionally, the following is also expected: + // static_assert(ScaleBlockM % MPerBlock == 0); + // static_assert(ScaleBlockN % NPerBlock == 0); + // In MX-native GEMM kernel these requirements should be relaxed. + // + // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 + // MfmaSelector::selected_mfma.k_per_blk = + // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 + // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float + + // clang-format off + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| + // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | + // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; + // clang-format on + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1}); + } + else + { + return HostTensorDescriptor({row, col}, {1, stride}); + } + }; + + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + if(K % Scale_Block_K != 0) + { + throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)"); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale( + f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification + Tensor c_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host + + if(config.verbosity >= 0) + { + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; + } + + switch(config.init_method) + { + case 0: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + break; + case 1: + case 2: + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + if(config.verbosity > 0) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {0.5}" << std::endl; + std::cout << "Init B = {1}" << std::endl; + std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + default: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + } + + if(config.verbosity > 0) + std::cout << "Device memory allocation..." 
<< std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(config.verbosity > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideC, + a_scale_device_buf.GetDeviceBuffer(), + b_scale_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong!\n" + "Provided combination of compilation and runtime parameters is " + "not consistent with the supported device_gemm arguments."); + } + + if(config.verbosity > 0) + std::cout << "Computing GEMM on device..." << std::endl; + float ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + + bool res_verified = true; + if(config.do_verification > 0) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Computing GEMM on host..." << std::endl; + } + + Tensor c({M, N}); + Tensor a({M, K}); + Tensor b({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a(m, k) = ck::type_convert(a_m_k(m, k)) * + ck::type_convert(a_m_k_scale(m, k / Scale_Block_K)); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b(k, n) = ck::type_convert(b_k_n(k, n)) * + ck::type_convert(b_k_n_scale(k / Scale_Block_K, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a, b, c, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Comparing results..." << std::endl; + } + + if(config.init_method == 1) + { + res_verified = + res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; + std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + } + + res_verified = res_verified && + ck::utils::check_err(c_m_n_device_result, c, "Error: Incorrect results!"); + + if(config.verbosity > 0 && res_verified) + std::cout << "Done." << std::endl; + } + else + { + if(config.verbosity > 0) + std::cout << "Done." 
<< std::endl; + } + + if(config.time_kernel) + { + std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + return res_verified; +} + +template +bool run_mx_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_mx_gemm(problem_size, config); +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..d2e21698ec --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +#if 1 +// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type +using XDataType = float; +#else +using XDataType = ck::e8m0_bexp_t; +#endif +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 128; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index f26d738625..bcb62df625 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -23,34 +23,34 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example source file ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR 
source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -83,6 +83,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any microscaling examples if gfx950 target is not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing microscaling example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #Do not build any FP8 examples if CK_ENABLE_FP8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") @@ -102,7 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) + elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -195,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) if(FILE_NAME MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 66f094557c..1ec0c6bc23 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // define general macros for various architectures #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) + defined(__gfx942__) || defined(__gfx950__) #define __gfx9__ #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || 
defined(__gfx1012__) @@ -163,6 +163,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // set rounding to nearest even as default for f8 conversions #define CK_USE_SR_F8_CONVERSION 0 +// set rounding to nearest even as default for f6 conversions +#define CK_USE_SR_F6_CONVERSION 0 + +// set rounding to nearest even as default for f4 conversions +#define CK_USE_SR_F4_CONVERSION 0 + // shuffle pk_i4 values during conversion to optimize number of binary // operations #define CK_USE_PK4_LAYOUT_SHUFFLE 1 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 3a590c676f..994e60025d 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -131,6 +131,10 @@ #cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@ #endif +#ifndef CK_USE_NATIVE_MX_SUPPORT +#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@ +#endif + // clang-format on #endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index f5c4b43ad2..05dc491af7 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -55,20 +55,21 @@ inline bool is_xdl_supported() { return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" || + ck::get_device_name() == "gfx950"; } inline bool is_bf16_atomic_supported() { return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942"; + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } inline bool is_gfx101_supported() diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 08bfefb87f..d33ecaeef8 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -26,6 +26,7 @@ namespace utils { template double get_relative_threshold(const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || @@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || 
is_same_v || @@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1) template double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { + using F4 = ck::f4_t; using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; @@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v, + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -450,5 +452,54 @@ check_err(const Range& out, return res; } +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f4_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 0.5, + double atol = 0.5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? 
err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
+            }
+            res = false;
+        }
+    }
+
+    if(!res)
+    {
+        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
+                  << " number of errors: " << err_count << std::endl;
+    }
+    return res;
+}
+
 } // namespace utils
 } // namespace ck
diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp
index 6a90523c33..274051da83 100644
--- a/include/ck/library/utility/host_tensor_generator.hpp
+++ b/include/ck/library/utility/host_tensor_generator.hpp
@@ -69,6 +69,18 @@ struct GeneratorTensor_1
 };
 #endif
 
+template <>
+struct GeneratorTensor_1<ck::f4_t>
+{
+    float value = 1.0;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        return ck::type_convert<ck::f4_t>(value);
+    }
+};
+
 template <>
 struct GeneratorTensor_1
 {
@@ -183,6 +195,20 @@ struct GeneratorTensor_2
 };
 #endif
 
+template <>
+struct GeneratorTensor_2<ck::f4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        float tmp = (std::rand() % (max_value - min_value)) + min_value;
+        return ck::type_convert<ck::f4_t>(tmp);
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_3
 {
@@ -253,6 +279,23 @@ struct GeneratorTensor_3
 };
 #endif
 
+template <>
+struct GeneratorTensor_3<ck::f4_t>
+{
+    float min_value = 0;
+    float max_value = 1;
+
+    template <typename... Is>
+    ck::f4_t operator()(Is...)
+    {
+        float tmp = float(std::rand()) / float(RAND_MAX);
+
+        float fp32_tmp = min_value + tmp * (max_value - min_value);
+
+        return ck::type_convert<ck::f4_t>(fp32_tmp);
+    }
+};
+
 template <typename T>
 struct GeneratorTensor_4
 {
diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
index d9c4e22049..00518b369f 100644
--- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
@@ -94,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
index 64aa398d53..d53fbca4ea 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
@@ -56,8 +56,7 @@ __global__ void
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
     const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     GridwiseGemm::template Run(p_as_grid,
diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d06eab1264..25a9d7f96d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -74,8 +74,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index e950169ccf..985752796b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -108,7 +107,7 @@ __global__ void ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index d6b92bc97a..630f143260 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -83,8 +83,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 6ab1669e30..f6c228fb7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -68,8 +68,7 @@ __global__ void const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - 
defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 34b1d503af..30ae72a63e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -59,8 +59,7 @@ __global__ void const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index e178b8f525..2662e5c360 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -67,8 +67,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -127,7 +126,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 9af1a44781..bfbcebd7c8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -62,8 +62,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -112,7 +111,7 @@ __global__ void ignore = batch_count; ignore = compute_base_ptr_of_batch; ignore = c0_matrix_mask; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // Computes C = A * B0 * B1 diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 6be2ffbdd7..494524b6f0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -52,8 +52,7 @@ __global__ void #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index dae16612cc..df5922a04f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -55,8 +55,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 6e69213513..8aa20f7ad4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -55,8 +55,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -97,7 +96,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } // specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index 811f1ae939..b9467ac194 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || 
defined(__gfx11__) || \ - defined(__gfx12__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index eaafd7d5c5..47fb630ea9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -63,8 +63,7 @@ __global__ void const Block2ETileMap block_2_etile_map, index_t NRaw) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; GridwiseGemmWelford::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index bb2db930c8..c048e7249c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 77ed9625c5..e6466a487b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -52,8 +52,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index 77e968bba5..2554ffea46 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -47,8 +47,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index cc022b89c5..1cf58fec25 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -37,8 +37,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index c8c58d5d85..99bd3be15d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -87,8 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a7df1c9d57..57c4b1a5cf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -60,8 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -103,7 +102,7 @@ __global__ void compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); compute_ptr_offset_of_batch.GetCPtrOffset(0); -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template ::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 42f7c2a33f..355e0130f1 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr index_t Gemm1KPack = math::max( - math::lcm( - MfmaSelector::selected_mfma.group_size, - B1K1), - MfmaSelector::selected_mfma.k_per_blk); + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index bc76d4cc4f..44a488c5dd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index afb2ad2e76..7d2dfab15f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
// therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 24fac91e22..4f20487b9b 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -37,7 +37,17 @@ enum struct MfmaInstr mfma_f32_32x32x16f8bf8, mfma_f32_16x16x32f8bf8, mfma_f32_32x32x16bf8f8, - mfma_f32_16x16x32bf8f8 + mfma_f32_16x16x32bf8f8, + mfma_f32_32x32x16f16, + mfma_f32_16x16x32f16, + mfma_f32_32x32x16bf16, + mfma_f32_16x16x32bf16, + mfma_i32_32x32x32i8, + mfma_i32_16x16x64i8, + mfma_f32_32x32x64f8f6f4, + mfma_f32_16x16x128f8f6f4, + mfma_scale_f32_32x32x64f8f6f4, + mfma_scale_f32_16x16x128f8f6f4 }; template @@ -198,6 +208,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -264,6 +318,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -286,6 +362,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static 
constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf16::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -440,6 +538,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x64i8::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { @@ -638,16 +780,115 @@ struct mfma_type } }; +// TODO: fix mfma...f8f6f4 instructions +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? 
group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr bool is_k_reduction = true; // ??? 
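+    // sanity check for the values above, assuming the K-reduction layout holds:
+    // num_input_blks * k_per_blk == 4 * 32 == 128, the K extent in the instruction name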
+ // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + template + typename additional_type = base_type, + bool is_single_rate_mfma = false> struct MfmaSelector { template + typename additional_type_ = base_type_, + bool is_single_rate_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -711,13 +952,32 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16f16; +#else + return MfmaInstr::mfma_f32_32x32x8f16; +#endif + } + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x8f16; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32f16; +#else + return MfmaInstr::mfma_f32_16x16x16f16; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x16f16; } @@ -741,7 +1001,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_32x32x8bf16_1k; @@ -751,7 +1023,19 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() { #if defined(CK_USE_AMD_MFMA_BF16_1K_OP) return MfmaInstr::mfma_f32_16x16x16bf16_1k; @@ -760,7 +1044,18 @@ struct MfmaSelector #endif } -#if defined(CK_USE_AMD_MFMA_GFX940) +#if defined(__gfx950__) + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_32x32x32i8; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_i32_16x16x64i8; + } +#elif defined(__gfx942__) template <> constexpr auto GetMfma() { @@ -832,8 +1127,8 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = - mfma_type()>{}; + static constexpr auto selected_mfma = mfma_type< + GetMfma()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1135,7 +1430,13 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942- + static constexpr auto + mfma = MfmaSelector < base_type, + MPerXdlops, NPerXdlops, additional_type, + ((is_same::value || is_same::value) && KPack <= 4) + ? 
true + : false > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 534a01e083..328e37d009 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index b4838277f1..42b784d303 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -20,39 +20,25 @@ #define CK_USE_OCP_FP8 0 #endif -namespace { -// https://en.cppreference.com/w/cpp/types/conditional -template -struct conditional -{ - using type = T; -}; -template -struct conditional -{ - using type = F; -}; -} // namespace - -namespace ck { - -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); - #if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \ - defined(__gfx1201__)) && \ + defined(__gfx1201__) || defined(__gfx950__)) && \ __HIP_DEVICE_COMPILE__ #define CK_FP8_CVT_FAST_PATH 1 #else #define CK_FP8_CVT_FAST_PATH 0 #endif -#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__ +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ #define CK_OCP_FP8_CVT_FAST_PATH 1 #else #define CK_OCP_FP8_CVT_FAST_PATH 0 #endif +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + typedef unsigned char fp8_storage_t; /** @@ -207,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) } } - typename conditional< + typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type retval; + typename std::conditional::type>::type + retval; if constexpr(we == 5 && is_half && !is_fnuz) { @@ -303,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); } } - #endif } // namespace fp8_impl @@ -378,7 +364,7 @@ struct bf8_ocp_t __host__ explicit operator float() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return fp8_impl::cast_to_f32_from_f8(this->data); #else return fp8_impl::cast_from_f8( @@ -392,7 +378,7 @@ struct bf8_ocp_t __host__ explicit operator _Float16() const #endif { -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); #else return fp8_impl::cast_from_f8<_Float16, wm, we, false>( @@ -553,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 
23 : 10); - using T_bitwise = typename conditional< + using T_bitwise = typename std::conditional< sizeof(T) == 2, unsigned short int, - typename conditional::type>::type; + typename std::conditional::type>::type; T_bitwise x_bitwise = bit_cast(_x); unsigned long long x{x_bitwise}; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 5a7030cca7..b125e3adf6 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -5,7 +5,7 @@ namespace ck { // Define the common macro for MI300 models -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif @@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64> } }; +template +struct intrin_mfma_f32_32x32x16f16; + +template <> +struct intrin_mfma_f32_32x32x16f16<32, 32> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32f16; + +template <> +struct intrin_mfma_f32_16x16x32f16<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8f16; @@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64> }; // bfp16 +template +struct intrin_mfma_f32_32x32x16bf16; + +template <> +struct intrin_mfma_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32bf16; + +template <> +struct intrin_mfma_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_f32_32x32x8bf16_1k; @@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_i32_32x32x32i8; + +template <> +struct intrin_mfma_i32_32x32x32i8<32, 32> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_16x16x64i8; + +template <> +struct 
intrin_mfma_i32_16x16x64i8<16, 16> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + template struct intrin_mfma_i32_32x32x16i8; @@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +template +struct intrin_mfma_f32_32x32x64f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6, +/// and f4 data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. +template <> +struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4; + +template <> +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4; + +template <> +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t scale_a, + const f8x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, // { OPSEL_HI[0], OPSEL[0] }? + scale_a, + 0, // { OPSEL_HI[1], OPSEL[1] }? + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x128f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4 +/// data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. 
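+///
+/// A minimal usage sketch (the accumulator shape is an assumption derived from
+/// num_regs_per_blk = 4 in the matching mfma_type entry, not code from this change):
+/// @code
+/// f8x32_t a_frag;               // 32 packed fp8 values per lane
+/// f8x32_t b_frag;
+/// vector_type<float, 4> acc{};  // a 16x16 tile leaves 4 f32 accumulators per lane
+/// intrin_mfma_f32_16x16x128f8f6f4<16, 16>::Run(a_frag, b_frag, acc);
+/// @endcode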
+template <> +struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + template struct intrin_mfma_f32_32x32x16f8f8; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 882d661331..f90fcf6791 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" #ifdef CK_CODE_GEN_RTC using int8_t = signed char; @@ -23,6 +24,296 @@ using std::byte; using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); +using f4_t = unsigned _BitInt(4); +using f6_t = _BitInt(6); // e2m3 format +using bf6_t = unsigned _BitInt(6); // e3m2 format + +struct f4x2_pk_t +{ + using type = uint8_t; + type data; + f4x2_pk_t() : data{type{}} {} + f4x2_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline type unpack(Number) const + { + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 0) + return data & 0b00001111; + else + return (data >> 4); + } + + __host__ __device__ inline type pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } +}; + +struct f6x16_pk_t +{ + // store 16 elements of f6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + f6x16_pk_t() : data{type{}} {} + f6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + 
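+            // worked example for the spill below: i = 5 gives bit_pos = 30, bit_offset = 30,
+            // overhang = 4, so the element's low two bits stay in word 0 and its high four
+            // bits carry into word 1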
// if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct f6x32_pk_t +{ + // store 32 elements of f6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + f6x32_pk_t() : data{type{}} {} + f6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline f6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 f6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x16_pk_t +{ + // store 16 elements of bf6_t in an array of 3 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + bf6x16_pk_t() : data{type{}} {} + bf6x16_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 16, "Index out of range for 16 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 16 
bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 16, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 3; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; + +struct bf6x32_pk_t +{ + // store 32 elements of bf6_t in an array of 6 uint32_t + using element_type = uint32_t; + using type = StaticallyIndexedArray_v2; + type data; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + bf6x32_pk_t() : data{type{}} {} + bf6x32_pk_t(type init) : data{init} {} + + template + __host__ __device__ inline bf6_t unpack(Number) + { + static_assert(I < 32, "Index out of range for 32 f6_t elements."); + + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = I * num_bits_elem; + constexpr int arr_idx = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + uint32_t bits = data.At(Number{}) >> bit_offset; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline type pack(const test_vec_t& x) + { + type packed{}; + + // for each of the 32 bf6_t values, place its 6 bits in the correct position + ck::static_for<0, 32, 1>{}([&](auto i) { + uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; + constexpr int num_bits_elem = 6; + constexpr int num_bits_vec_elem = 32; + constexpr int vector_size = 6; + constexpr int bit_pos = i * num_bits_elem; + constexpr int arr_index = bit_pos / num_bits_vec_elem; + constexpr int bit_offset = bit_pos % num_bits_vec_elem; + constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = packed.At(Number{}); + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + packed.At(Number{}) = old_value; + + // if it crosses into the next block, shift the remainder + if constexpr(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = packed.At(Number{}); + next_value |= (bits >> (num_bits_elem - overhang)); + packed.At(Number{}) = next_value; + } + }); + + return packed; + } +}; // custom data type - pack int4 data struct pk_i4_t @@ -40,14 +331,15 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool +// native types: bool, f4_t, f6_t, bf6_t template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || 
is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value; } // vector_type @@ -1370,12 +1662,37 @@ struct nnvb_data_t_selector { using type = f8_ocp_t::data_type; }; + template <> struct nnvb_data_t_selector { using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::type; +}; + template <> struct nnvb_data_t_selector { @@ -1482,6 +1799,63 @@ struct non_native_vector_base< } }; +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} + { + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } +}; + template struct scalar_type>; @@ -2217,6 +2591,22 @@ using uint8x16_t = typename vector_type::type; using uint8x32_t = typename vector_type::type; using uint8x64_t = typename vector_type::type; +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + // pack int4 using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; @@ -2571,6 +2961,118 @@ struct NumericLimits }; #endif +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + 
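+    // e2m1 has a single subnormal encoding (0b0001 = 0.5), so the min and max
+    // subnormal patterns coincide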
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; 
// 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; + template struct NumericUtils { @@ -2590,6 +3092,7 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; using bitwise_type = uint32_t; }; @@ -2607,9 +3110,19 @@ struct NumericUtils static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; using bitwise_type = uint16_t; }; +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; + template <> struct NumericUtils { @@ -2617,6 +3130,7 @@ struct NumericUtils static constexpr int mant = 3; static constexpr int bias = 8; // negative zero nan mode // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; }; template <> @@ -2626,6 +3140,7 @@ struct NumericUtils static constexpr int mant = 2; static constexpr int bias = 16; // negative zero nan mode // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; }; template <> struct NumericUtils @@ -2644,11 +3159,109 @@ struct NumericUtils }; template <> -struct NumericUtils +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int 
biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils { static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; }; } // namespace ck diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp new file mode 100644 index 0000000000..a692f533f8 --- /dev/null +++ b/include/ck/utility/e8m0.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/type.hpp" + +namespace ck { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. 
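+ *
+ * Construction from float keeps only the exponent field, e.g. e8m0_bexp_t(8.0f)
+ * stores 0x82 (= 127 + 3) and converts back to 8.0f.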
+ * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ +struct e8m0_bexp_t +{ + using type = uint8_t; + type data; + + constexpr static type bias = 127; + constexpr static type nan_mask = 0xFF; + + __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {} + __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {} + __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast(init & nan_mask)} + { + } + __host__ __device__ explicit constexpr e8m0_bexp_t(float scale) + : data{static_cast((bit_cast(scale) & (nan_mask << 23)) >> 23)} + { + } + + __host__ __device__ explicit constexpr operator float() const + { + if(data == nan_mask || data == 0) + { + uint32_t bits = data << 1; + bits |= 1; + bits <<= 22; + return bit_cast(bits); + } + else + { + uint32_t bits = data << 23; + return bit_cast(bits); + } + } + + __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const + { + // strict IEEE compliance for NaN + return data == other.data && data != nan_mask; + } + + __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } +}; + +namespace utils { + +template +__host__ __device__ inline int get_exponent_value(T x); + +template <> +__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +{ + return x.data; +} + +} // namespace utils + +} // namespace ck diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp new file mode 100644 index 0000000000..15e693bd0d --- /dev/null +++ b/include/ck/utility/mxf4_utils.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/mxfp_utils.hpp" + +namespace ck::utils { + +template <> +__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, + f4_t const dataBytes [[maybe_unused]]) +{ + // no need to check for data as it does not have NaN representation + return scale == NumericLimits::QuietNaN(); +} + +// no infinity representation in ocp_e2m1_mxfp4 will always return false +template <> +__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale [[maybe_unused]], + f4_t const data [[maybe_unused]]) +{ + // no inf representation for ocp_e2m1_mxfp4 + return false; +} + +template <> +__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, f4_t const data) +{ + if(is_nan(scale, data)) + return false; + + // no need to check for scale as it does not have a 0 representation + f4_t result = (data & 0b00001111) & NumericUtils::set_sign_mask; + + return result == 0b0; +} + +template <> +__host__ __device__ inline float to_float(e8m0_bexp_t const scale, f4_t const data) +{ + if(is_nan(scale, data)) + return std::numeric_limits::quiet_NaN(); + + if(is_zero(scale, data)) + return 0.0f; + + f4_t prepared_data = data & 0b00001111; + + int scale_exp = get_exponent_value(scale); + + return convert_to_float(prepared_data, scale_exp); +} + +template <> +__host__ __device__ inline f4_t sat_convert_to_type(float value) +{ + cvt t; + t.value_float = value; + uint32_t sign = t.value_bitwise >> 31; + + if(std::isnan(value)) + { + + return sign ? 
NumericUtils::data_max_negative_normal_mask + : NumericUtils::data_max_positive_normal_mask; + } + + if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + return sign ? NumericUtils::data_max_negative_normal_mask + : NumericUtils::data_max_positive_normal_mask; + + f4_t res = convert_to_type(value); + + if(std::abs(to_float(NumericLimits::Binary_1(), res)) < + NumericLimits::DataMinSubnorm()) + return value < 0 ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; + + return res; +} + +template <> +__host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32_t seed) +{ + cvt t; + t.value_float = value; + uint32_t sign = t.value_bitwise >> 31; + + if(std::isnan(value)) + return sign ? NumericUtils::data_max_negative_normal_mask + : NumericUtils::data_max_positive_normal_mask; + + if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + return sign ? NumericUtils::data_max_negative_normal_mask + : NumericUtils::data_max_positive_normal_mask; + + f4_t res = convert_to_type_sr(value, seed); + + if(std::abs(to_float(NumericLimits::Binary_1(), res)) < + NumericLimits::DataMinSubnorm()) + return value < 0 ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; + + return res; +} + +} // namespace ck::utils diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp new file mode 100644 index 0000000000..e3b37bedda --- /dev/null +++ b/include/ck/utility/mxf6_utils.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/mxfp_utils.hpp" + +namespace ck::utils { + +/** + * @brief Checks if an f6_t value is NaN based on the provided scale. + * + * For f6_t data, NaN cannot be represented directly. Instead, this function + * determines NaN by checking if the scale is set to a quiet NaN. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t. + * @param dataBytes The f6_t value to check (unused in this implementation). + * @return true if the scale indicates a NaN value, false otherwise. + */ +template <> +__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, + f6_t const dataBytes [[maybe_unused]]) +{ + // no need to check for data as it does not have NaN representation + return scale.is_nan(); +} + +/** + * @brief Checks if an bf6_t value is NaN based on the provided scale. + * + * For bf6_t data, NaN cannot be represented directly. Instead, this function + * determines NaN by checking if the scale is set to a quiet NaN. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t. + * @param dataBytes The bf6_t value to check (unused in this implementation). + * @return true if the scale indicates a NaN value, false otherwise. + */ +template <> +__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, + bf6_t const dataBytes [[maybe_unused]]) +{ + // no need to check for data as it does not have NaN representation + return scale.is_nan(); +} + +/** + * @brief Checks if an f6_t value is infinite. + * + * Because f6_t does not support infinite values, this function always returns false. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t. + * @param data The f6_t value to check. + * @return Always false, as infinity is not represented in f6_t. 
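+ *
+ * @note Out-of-range values saturate to the maximum normal encoding during
+ * conversion (see sat_convert_to_type below), so no infinity can be produced.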
+ */ +template <> +__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale [[maybe_unused]], + f6_t const data [[maybe_unused]]) +{ + // no inf representation for fp6 + return false; +} + +/** + * @brief Checks if an bf6_t value is infinite. + * + * Because bf6_t does not support infinite values, this function always returns false. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t. + * @param data The bf6_t value to check. + * @return Always false, as infinity is not represented in bf6_t. + */ +template <> +__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale [[maybe_unused]], + bf6_t const data [[maybe_unused]]) +{ + // no inf representation for bf6 + return false; +} + +/** + * @brief Checks whether an f6_t value is zero. + * + * If the specified f6_t is NaN, this function returns false. + * Otherwise, it masks out the sign bits and checks if the remaining bits + * are zero. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t. + * @param data The f6_t value to check. + * @return true if the value is zero; otherwise false. + */ +template <> +__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, f6_t const data) +{ + if(is_nan(scale, data)) + return false; + + // no need to check for scale as it does not have a 0 representation + f6_t result = (data & 0b00111111) & NumericUtils::set_sign_mask; + + return result == 0b0; +} + +/** + * @brief Checks whether an bf6_t value is zero. + * + * If the specified bf6_t is NaN, this function returns false. + * Otherwise, it masks out the sign bits and checks if the remaining bits + * are zero. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t. + * @param data The bf6_t value to check. + * @return true if the value is zero; otherwise false. + */ +template <> +__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, bf6_t const data) +{ + if(is_nan(scale, data)) + return false; + + // no need to check for scale as it does not have a 0 representation + bf6_t result = (data & 0b00111111) & NumericUtils::set_sign_mask; + + return result == 0b0; +} + +/** + * @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor. + * + * Checks if the f6_t value is NaN or zero before performing the conversion. + * Applies the exponent from the scale to compute the final float result. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t. + * @param data The f6_t value to convert. + * @return The converted float value. + */ +template <> +__host__ __device__ inline float to_float(e8m0_bexp_t const scale, f6_t const data) +{ + if(is_nan(scale, data)) + return std::numeric_limits::quiet_NaN(); + + if(is_zero(scale, data)) + return 0.0f; + + f6_t prepared_data = data & 0b00111111; + + int scale_exp = get_exponent_value(scale); + + return convert_to_float(prepared_data, scale_exp); +} + +/** + * @brief Converts an bf6_t value to a float based on an e8m0_bexp_t scale factor. + * + * Checks if the bf6_t value is NaN or zero before performing the conversion. + * Applies the exponent from the scale to compute the final float result. + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t. + * @param data The bf6_t value to convert. + * @return The converted float value. 
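+ *
+ * @note Worked example with the e2m3 encoding: data = 0b001000 decodes to 1.0, and
+ * scale = 0x80 contributes 2^1, so the result is 2.0f.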
+
+/**
+ * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
+ *
+ * Checks if the bf6_t value is NaN or zero before performing the conversion.
+ * Applies the exponent from the scale to compute the final float result.
+ *
+ * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
+ * @param data The bf6_t value to convert.
+ * @return The converted float value.
+ */
+template <>
+__host__ __device__ inline float to_float(e8m0_bexp_t const scale, bf6_t const data)
+{
+    if(is_nan(scale, data))
+        return std::numeric_limits<float>::quiet_NaN();
+
+    if(is_zero(scale, data))
+        return 0.0f;
+
+    bf6_t prepared_data = data & 0b00111111;
+
+    int scale_exp = get_exponent_value(scale);
+
+    return convert_to_float(prepared_data, scale_exp);
+}
+
+/**
+ * @brief Converts a float to f6_t with saturation.
+ *
+ * If the input is NaN or exceeds the representable range for f6_t, returns
+ * the corresponding max normal mask. Handles subnormal cases by returning
+ * zero with the appropriate sign.
+ *
+ * @param value The float value to be converted.
+ * @return The saturated f6_t value.
+ */
+template <>
+__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
+{
+    cvt t;
+    t.value_float = value;
+    uint32_t sign = t.value_bitwise >> 31;
+
+    if(std::isnan(value))
+    {
+        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
+                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
+    }
+
+    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
+        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
+                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
+
+    f6_t res = convert_to_type<f6_t>(value);
+
+    if(std::abs(to_float(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
+       NumericLimits<f6_t>::DataMinSubnorm())
+        return sign ? NumericUtils<f6_t>::negative_zero_mask
+                    : NumericUtils<f6_t>::positive_zero_mask;
+
+    return res;
+}
+
+/**
+ * @brief Converts a float to bf6_t with saturation.
+ *
+ * If the input is NaN or exceeds the representable range for bf6_t, returns
+ * the corresponding max normal mask. Handles subnormal cases by returning
+ * zero with the appropriate sign.
+ *
+ * @param value The float value to be converted.
+ * @return The saturated bf6_t value.
+ */
+template <>
+__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
+{
+    cvt t;
+    t.value_float = value;
+    uint32_t sign = t.value_bitwise >> 31;
+
+    if(std::isnan(value))
+    {
+        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
+                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
+    }
+
+    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
+        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
+                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
+
+    bf6_t res = convert_to_type<bf6_t>(value);
+
+    if(std::abs(to_float(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
+       NumericLimits<bf6_t>::DataMinSubnorm())
+        return sign ? NumericUtils<bf6_t>::negative_zero_mask
+                    : NumericUtils<bf6_t>::positive_zero_mask;
+
+    return res;
+}
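+
+// Saturation sketch (illustrative; 7.5 assumes the E2M3 max-normal value):
+// out-of-range floats clamp to the max-magnitude normal instead of producing
+// inf or NaN, since neither is representable in f6_t.
+//
+// @code
+// f6_t hi = ck::utils::sat_convert_to_type<f6_t>(100.0f);  // +7.5 pattern
+// f6_t lo = ck::utils::sat_convert_to_type<f6_t>(-100.0f); // -7.5 pattern
+// @endcode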
+
+/**
+ * @brief Converts a float to f6_t with saturation and stochastic rounding.
+ *
+ * If the input is NaN or exceeds the representable range for f6_t, returns
+ * the corresponding max normal mask. Handles subnormal cases by returning
+ * zero with the appropriate sign.
+ *
+ * @param value The float value to be converted.
+ * @param seed The random value used for stochastic rounding.
+ * @return The saturated f6_t value.
+ */
+template <>
+__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
+{
+    cvt t;
+    t.value_float = value;
+    uint32_t sign = t.value_bitwise >> 31;
+
+    if(std::isnan(value))
+        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
+                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
+
+    if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
+        return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
+                    : NumericUtils<f6_t>::data_max_positive_normal_mask;
+
+    f6_t res = convert_to_type_sr<f6_t>(value, seed);
+
+    if(std::abs(to_float(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
+       NumericLimits<f6_t>::DataMinSubnorm())
+        return sign ? NumericUtils<f6_t>::negative_zero_mask
+                    : NumericUtils<f6_t>::positive_zero_mask;
+
+    return res;
+}
+
+/**
+ * @brief Converts a float to bf6_t with saturation and stochastic rounding.
+ *
+ * If the input is NaN or exceeds the representable range for bf6_t, returns
+ * the corresponding max normal mask. Handles subnormal cases by returning
+ * zero with the appropriate sign.
+ *
+ * @param value The float value to be converted.
+ * @param seed The random value used for stochastic rounding.
+ * @return The saturated bf6_t value.
+ */
+template <>
+__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
+{
+    cvt t;
+    t.value_float = value;
+    uint32_t sign = t.value_bitwise >> 31;
+
+    if(std::isnan(value))
+        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
+                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
+
+    if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
+        return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
+                    : NumericUtils<bf6_t>::data_max_positive_normal_mask;
+
+    bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
+
+    if(std::abs(to_float(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
+       NumericLimits<bf6_t>::DataMinSubnorm())
+        return sign ? NumericUtils<bf6_t>::negative_zero_mask
+                    : NumericUtils<bf6_t>::positive_zero_mask;
+
+    return res;
+}
+
+} // namespace ck::utils
diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp
new file mode 100644
index 0000000000..2dbf997f6a
--- /dev/null
+++ b/include/ck/utility/mxf8_utils.hpp
@@ -0,0 +1,570 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/mxfp_utils.hpp"
+
+#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
+#define CK_MX_FP8_CVT_FAST_PATH 1
+#else
+#define CK_MX_FP8_CVT_FAST_PATH 0
+#endif
+
+namespace ck {
+
+namespace fp8_impl {
+#if CK_MX_FP8_CVT_FAST_PATH
+template <ck_fp8_interpretation_t interpret>
+static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
+{
+    union
+    {
+        unsigned int i32val;
+        unsigned char i8val[4];
+    } val;
+    val.i8val[0] = v;
+
+    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
+                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
+                  "Only OCP interpretations are supported");
+
+    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
+    {
+        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
+    }
+    else
+    {
+        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
+    }
+}
+
+template <ck_fp8_interpretation_t interpret>
+static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
+{
+    const auto i16val = bit_cast<int16_t>(v);
+
+    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
+                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
+                  "Only OCP interpretations are supported");
+
+    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
+    {
+        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
+    }
+    else
+    {
+        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
+    }
+}
+
+template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
+static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
+                                                           unsigned int rng = 0,
+                                                           float scale      = 1.0f)
+{
+    fp8_storage_t i8data;
+    union
+    {
+        float fval;
+        unsigned int i32val;
+    } val;
+
+    union
+    {
+        uint32_t ival;
+        vector_type<int16_t, 2>::type v2i16;
+        fp8_storage_t v4i8[4];
+    } ret{};
+
+    val.fval = v;
+
+    if constexpr(stochastic_rounding)
+    {
+        ret.ival =
+            (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
+                ?
__builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0) + : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0); + + i8data = ret.v4i8[0]; + } + else + { + // RNE CVT + // llvm.amdgcn.cvt.scalef32.pk.fp8.f32 + // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + // If fval / scale > max fp8, returns Nan + ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16, + val.fval, + val.fval, + scale, + /*dst_lo_hi_sel*/ false); + } + else + { + // If fval / scale > max bf8, returns Inf + ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16, + val.fval, + val.fval, + scale, + /*dst_lo_hi_sel*/ false); + } + + i8data = ret.v4i8[0]; + } + return i8data; +} + +template +static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v, + unsigned int rng = 0, + float scale = 1.0f) +{ + + union + { + uint32_t ival; + vector_type::type v2i16; + StaticallyIndexedArray v2f8x2; + } ret{}; + + if constexpr(stochastic_rounding) + { + fp8x2_storage_t f8x2; + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0); + f8x2[0] = ret.v2f8x2(Number<0>{})[0]; + ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0); + f8x2[1] = ret.v2f8x2(Number<0>{})[0]; + } + else + { + ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0); + f8x2[0] = ret.v2f8x2(Number<0>{})[0]; + ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0); + f8x2[1] = ret.v2f8x2(Number<0>{})[0]; + } + return f8x2; + } + else + { + // RNE CVT + // llvm.amdgcn.cvt.scalef32.pk.fp8.f32 + // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + // If fval / scale > max fp8, returns Nan + ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16, + v[0], + v[1], + scale, + /*dst_lo_hi_sel*/ false); + } + else + { + // If fval / scale > max bf8, returns Inf + ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16, + v[0], + v[1], + scale, + /*dst_lo_hi_sel*/ false); + } + + return ret.v2f8x2(Number<0>{}); + } +} + +#endif // CK_MX_FP8_CVT_FAST_PATH + +#if CK_MX_FP8_CVT_FAST_PATH +/** + * \brief convert float to @p fp8_storage_t with scaling + * + * This version is used when the fast path (MX FP8 hardware) is available + * + * \tparam interp interpretation of fp8 + * \param f float number + * \param scale scaling factor + * \return fp8_storage_t + */ +template +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f); + } + return cast_to_f8_from_f32_scaled(f, rng, scale); +} + +/** + * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling + * + * This version is used when the fast path (MX FP8 hardware) is available + * + * \tparam interp interpretation of fp8 + * \param f 2xfloat + * \param scale scaling factor + * \return 2xfp8_storage_t + */ +template +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f, + float scale) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if 
constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f[0]); + } + return cast_to_f8_from_f32_scaled(f, rng, scale); +} + +#else + +/** + * \brief convert float to @p fp8_storage_t with scaling + * + * This version is used when the fast path (MX FP8 hardware) is not available + * + * \tparam interp interpretation of fp8 + * \param f float number + * \param scale scaling factor + * \return fp8_storage_t + */ +template +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale) +{ + + static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP || + interp == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only OCP interpretations are supported"); + + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f); + } + + if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + return cast_to_f8(f / scale, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP) + { + return cast_to_f8(f / scale, rng); + } + else + { + __hip_assert(false && "FP8 type is not supported by current target device"); + return 0; + } +} + +/** + * \brief convert two float to @p 2xfp8_storage_t with scaling + * + * This version is used when the fast path (MX FP8 hardware) is not available + * + * \tparam interp interpretation of fp8 + * \param f 2xfloat + * \param scale scaling factor + * \return 2xfp8_storage_t + */ +template +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f, + float scale) +{ + + static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP || + interp == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only OCP interpretations are supported"); + + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; + rng = prand_generator(reinterpret_cast(&f), f[0]); + } + + if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + return {cast_to_f8(f[0] / scale, rng), + cast_to_f8(f[1] / scale, rng)}; + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP) + { + return {cast_to_f8(f[0] / scale, rng), + cast_to_f8(f[1] / scale, rng)}; + } + else + { + __hip_assert(false && "FP8 type is not supported by current target device"); + return 0; + } +} + +#endif // CK_MX_FP8_CVT_FAST_PATH + +} // namespace fp8_impl + +// Declare a template function for fp8 conversion using SR +template +__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale); + +// Declare a template function for fp8 conversion using RNE +template +__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale); + +// convert fp32 to fp8 with rounding to nearest even +template <> +inline __host__ __device__ f8_ocp_t mxf8_convert_rne(float x, float scale) +{ + return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32 to bf8 with rounding to nearest even +template <> +inline __host__ __device__ bf8_ocp_t mxf8_convert_rne(float x, float scale) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x2 to fp8x2 with rounding to nearest even +template <> +inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne(float2_t x, + float scale) +{ + return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x2 to bf8x2 with rounding to nearest even +template <> +inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne(float2_t x, + float scale) +{ + return 
bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x16 to fp8x16 with rounding to nearest even +template <> +inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne(float16_t x, + float scale) +{ + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } in{x}; + + union + { + f8x16_ocp_t fp8_1x16; + f8x2_ocp_t fp8_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}( + [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne(in.float_2x8[i], scale); }); + + return out.fp8_1x16; +} + +// convert fp32x16 to bf8x16 with rounding to nearest even +template <> +inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne(float16_t x, + float scale) +{ + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } in{x}; + + union + { + bf8x16_ocp_t bf8_1x16; + bf8x2_ocp_t bf8_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}( + [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne(in.float_2x8[i], scale); }); + + return out.bf8_1x16; +} + +// convert fp32x32 to fp8x32 with rounding to nearest even +template <> +inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne(float32_t x, + float scale) +{ + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } in{x}; + + union + { + f8x32_ocp_t fp8_1x32; + f8x16_ocp_t fp8_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}( + [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne(in.float_16x2[i], scale); }); + + return out.fp8_1x32; +} + +// convert fp32x32 to bf8x32 with rounding to nearest even +template <> +inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne(float32_t x, + float scale) +{ + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } in{x}; + + union + { + bf8x32_ocp_t bf8_1x32; + bf8x16_ocp_t bf8_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}( + [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne(in.float_16x2[i], scale); }); + + return out.bf8_1x32; +} + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_ocp_t mxf8_convert_sr(float x, float scale) +{ + return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_ocp_t mxf8_convert_sr(float x, float scale) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x2 to fp8x2 with stochastic rounding +template <> +inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr(float2_t x, float scale) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x2 to bf8x2 with stochastic rounding +template <> +inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr(float2_t x, + float scale) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8_scaled(x, scale)}; +} + +// convert fp32x16 to fp8x16 with stochastic rounding +template <> +inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr(float16_t x, + float scale) +{ + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } in{x}; + + union + { + f8x16_ocp_t fp8_1x16; + f8x2_ocp_t fp8_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}( + [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr(in.float_2x8[i], scale); }); + + return out.fp8_1x16; +} + +// convert fp32x16 to bf8x16 with stochastic rounding +template <> +inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr(float16_t x, + float scale) +{ + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } in{x}; + + union + { + bf8x16_ocp_t bf8_1x16; + bf8x2_ocp_t bf8_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}( + [&](auto i) { out.bf8_2x8[i] = 
mxf8_convert_sr(in.float_2x8[i], scale); }); + + return out.bf8_1x16; +} + +// convert fp32x32 to fp8x32 with stochastic rounding +template <> +inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr(float32_t x, + float scale) +{ + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } in{x}; + + union + { + f8x32_ocp_t fp8_1x32; + f8x16_ocp_t fp8_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}( + [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr(in.float_16x2[i], scale); }); + + return out.fp8_1x32; +} + +// convert fp32x32 to bf8x32 with stochastic rounding +template <> +inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr(float32_t x, + float scale) +{ + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } in{x}; + + union + { + bf8x32_ocp_t bf8_1x32; + bf8x16_ocp_t bf8_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}( + [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr(in.float_16x2[i], scale); }); + + return out.bf8_1x32; +} + +} // namespace ck diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp new file mode 100644 index 0000000000..e23836c87f --- /dev/null +++ b/include/ck/utility/mxfp_utils.hpp @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck::utils { + +union cvt +{ + float value_float; + uint32_t value_bitwise; +}; + +template +inline bool getDataHasInf() +{ + return DTYPE::dataInfo.hasInf; +} + +template +__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data); + +template +__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data); + +template +__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data); + +template +__host__ __device__ inline int get_exponent_value(T x) +{ + x >>= NumericUtils::mant; + + x &= ((1 << NumericUtils::exp) - 1); + + return static_cast(x); +} + +template +__host__ __device__ inline bool is_subnormal(T x) +{ + return get_exponent_value(x) == 0; +} + +template +__host__ __device__ inline double get_mantissa_value(T x) +{ + double mantissa = is_subnormal(x) ? 
0.0f : 1.0f; + + for(uint i = 0; i < NumericUtils::mant; i++) + { + + mantissa += std::pow(2, -int32_t((NumericUtils::mant - i))) * (x & 0b1); + + x >>= 1; + } + + return mantissa; +} + +template +__host__ __device__ inline bool get_data_has_inf() +{ + return NumericUtils::has_inf; +} + +template +__host__ __device__ float convert_to_float(T data, int scale_exp) +{ + float d_sign = + std::pow(-1, static_cast(data >> (NumericUtils::exp + NumericUtils::mant))); + + float d_exp; + if(is_subnormal(data)) + d_exp = std::pow(2, 1 - static_cast(NumericUtils::bias)); + else + d_exp = std::pow(2, get_exponent_value(data) - static_cast(NumericUtils::bias)); + float d_mant = get_mantissa_value(data); + + float data_value = d_sign * d_exp * d_mant; + float scale_value = std::pow( + 2, static_cast((scale_exp - static_cast(NumericUtils::bias)))); + + return data_value * scale_value; +} + +template +__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data); + +template +__host__ __device__ T sat_convert_to_type(float value); + +template +__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed); + +template +inline T convert_to_type(float value) +{ + using bitwise_type = typename NumericUtils::bitwise_type; + + if(std::abs(value) > NumericLimits::Max()) + { + float max_value = NumericLimits::Max(); + + cvt t; + + // cppcheck-suppress redundantAssignment + t.value_float = max_value; + uint32_t max_bitwise = t.value_bitwise; + + // cppcheck-suppress redundantAssignment + t.value_float = value; + bitwise_type sign = + t.value_bitwise >> (NumericUtils::exp + NumericUtils::mant); + bitwise_type exp = + ((max_bitwise >> NumericUtils::mant) & NumericUtils::exp_mask) - + (NumericUtils::bias - NumericUtils::bias); + bitwise_type mantissa = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); + + uint32_t mant_prev = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); + mant_prev &= ((1 << NumericUtils::mant) - 1); + mant_prev--; + + mant_prev <<= (NumericUtils::mant - NumericUtils::mant); + uint32_t prev_bit = + ((max_bitwise >> NumericUtils::mant) << NumericUtils::mant) | mant_prev; + + t.value_bitwise = prev_bit; + float prev_val = t.value_float; + float diff = max_value - prev_val; + + float actual_max = max_value + (diff / 2); + + if(std::abs(value) < actual_max) + { + return sign << ((NumericUtils::exp + NumericUtils::mant)) | + (exp << NumericUtils::mant) | mantissa; + } + else + { + if(!get_data_has_inf()) + { + + return (1 << (NumericUtils::mant + NumericUtils::exp)) - 1; + } + else + { + exp++; + return sign << ((NumericUtils::exp + NumericUtils::mant)) | + (exp << NumericUtils::mant); + } + } + } + const int mfmt = NumericUtils::mant; + uint32_t x; + x = bit_cast(value); + + uint32_t head, mantissa; + int32_t exponent, bias; + uint32_t sign; + + head = x & NumericUtils::head_mask; + mantissa = x & NumericUtils::mant_mask; + exponent = (head >> NumericUtils::mant) & NumericUtils::exp_mask; + sign = head >> (NumericUtils::mant + NumericUtils::exp); + bias = NumericUtils::bias; + + if(x == 0) + { + return 0b0; + } + + const int mini_bias = NumericUtils::bias; + const int mini_denormal_act_exponent = 1 - mini_bias; + + int act_exponent, out_exponent, exponent_diff; + + bool is_subnorm = false; + + if(exponent == 0) + { + act_exponent = exponent - bias + 1; + exponent_diff = mini_denormal_act_exponent - act_exponent; + is_subnorm = true; + } + else + { + act_exponent = exponent - bias; + if(act_exponent <= mini_denormal_act_exponent) + { + exponent_diff = 
mini_denormal_act_exponent - act_exponent; + is_subnorm = true; + } + else + { + exponent_diff = 0; + } + mantissa += (1UL << mfmt); + } + + auto shift_amount = (mfmt - NumericUtils::mant + exponent_diff); + shift_amount = (shift_amount >= 64) ? 63 : shift_amount; + bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1)); + + float min_subnorm = NumericLimits::DataMinSubnorm() * (sign ? -1 : 1); + + if(is_subnorm && std::abs(value) < std::abs(min_subnorm)) + { + // closer to 0 + if(std::abs(value) <= std::abs(min_subnorm - value)) + return 0; + else + return 1 | (sign << (NumericUtils::exp + NumericUtils::mant)); + } + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); + out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1); + + uint32_t drop_mask = (1UL << (mfmt - NumericUtils::mant)) - 1; + bool odd = mantissa & (1UL << (mfmt - NumericUtils::mant)); + mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask; + + if(out_exponent == 0) + { + if((1UL << mfmt) & mantissa) + { + out_exponent = 1; + } + } + else + { + if((1UL << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + } + } + + mantissa >>= (mfmt - NumericUtils::mant); + + if(out_exponent == 0 && mantissa == 0) + { + return 0; + } + + mantissa &= (1UL << NumericUtils::mant) - 1; + return (sign << (NumericUtils::exp + NumericUtils::mant)) | + (out_exponent << NumericUtils::mant) | mantissa; +} + +template +inline T convert_to_type_sr(float value, uint32_t seed) +{ + if(std::abs(value) > NumericLimits::Max()) + { + float max_value = NumericLimits::Max(); + + cvt t; + + // cppcheck-suppress redundantAssignment + t.value_float = max_value; + uint max_bitwise = t.value_bitwise; + + // cppcheck-suppress redundantAssignment + t.value_float = value; + T sign = t.value_bitwise >> (NumericUtils::exp + NumericUtils::mant); + T exp = ((max_bitwise >> NumericUtils::mant) & NumericUtils::exp_mask) - + (NumericUtils::bias - NumericUtils::bias); + + uint32_t mant_prev = max_bitwise >> (NumericUtils::mant - NumericUtils::mant); + mant_prev &= ((1UL << NumericUtils::mant) - 1); + mant_prev--; + + mant_prev <<= (NumericUtils::mant - NumericUtils::mant); + uint32_t prev_bit = + ((max_bitwise >> NumericUtils::mant) << NumericUtils::mant) | mant_prev; + + t.value_bitwise = prev_bit; + float prev_val = t.value_float; + float diff = max_value - prev_val; + + float actual_max = max_value + (diff / 2); + + if(std::abs(value) < actual_max) + { + double d_max_value = static_cast(max_value); + double d_actual_max = static_cast(actual_max); + double d_value = static_cast(value); + double d_is = std::abs(d_max_value - d_actual_max); + double d_seed = static_cast(seed); + double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down + + double thresh = UINT_MAX * d_prob; + + if(!get_data_has_inf() || d_seed <= thresh) + // return static_cast(satConvertToType(getDataMax())); //round down time + return sign == 0 ? 
NumericUtils::data_max_positive_normal_mask + : NumericUtils::data_max_negative_normal_mask; + else + { + exp++; + return sign << ((NumericUtils::exp + NumericUtils::mant)) // inf + | (exp << NumericUtils::mant); + } + } + else + { + if(!get_data_has_inf()) + return (1 << (NumericUtils::mant + NumericUtils::exp)) - 1; + else + { + exp++; + return sign << ((NumericUtils::exp + NumericUtils::mant)) // inf + | (exp << NumericUtils::mant); + } + } + } + + uint32_t f32 = bit_cast(value); + + auto f32_mant = f32 & NumericUtils::mant_mask; + auto head = f32 & NumericUtils::head_mask; + auto f32_exp = (head >> NumericUtils::mant) & NumericUtils::exp_mask; + + auto sign_bit = head >> (NumericUtils::mant + NumericUtils::exp); + auto sign = sign_bit << (NumericUtils::exp + NumericUtils::mant); + + f32_exp = static_cast(f32_exp) - NumericUtils::bias; + int32_t exp = f32_exp; + auto mant = f32_mant; + bool subnorm = false; + + if(f32 == 0) + return 0b0; + + if(exp >= NumericUtils::unbiased_exp_min) + { + mant = f32_mant; + } + // if the exponent bit is 8, then the subnormal is exactly the same as f32 + else if(exp < NumericUtils::unbiased_exp_min && + NumericUtils::exp < NumericUtils::exp) + { + subnorm = true; + auto diff = static_cast(NumericUtils::unbiased_exp_min - exp); + if(diff >= 32) + { + mant = 0; + f32_mant = 0; + } + else + { + f32_mant |= static_cast(1) << NumericUtils::mant; + f32_mant >>= diff; + } + exp = 0; + mant = f32_mant; + } + + uint32_t sr_shift = NumericUtils::sr_shift; + + // For stochastic-rounding we add the aligned random value to the + // mantissa and then truncate (RTZ). + mant += seed >> sr_shift; + + // Increment exponent when mantissa overflows due to rounding + if(mant >= static_cast(1) << NumericUtils::mant) + ++exp; + mant >>= (NumericUtils::mant - NumericUtils::mant); + mant &= ((1 << NumericUtils::mant) - 1); + + auto biased_exp = static_cast(exp); + if(!subnorm) + biased_exp = static_cast(exp + NumericUtils::bias); + biased_exp &= ((1 << NumericUtils::exp) - 1); + auto val = sign | biased_exp << NumericUtils::mant | mant; + return val; +} + +} // namespace ck::utils diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp new file mode 100644 index 0000000000..5b7a822e1f --- /dev/null +++ b/include/ck/utility/scaled_type_convert.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
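+
+// Usage sketch (illustrative only; the helper names below exist in this
+// header, the values are assumptions): an MX block pairs one shared e8m0
+// scale with packed low-precision elements, and scaled_type_convert widens
+// to fp32 while multiplying by 2^(scale_exp - 127).
+//
+// @code
+// e8m0_bexp_t scale = NumericLimits<e8m0_bexp_t>::Binary_1(); // factor 1.0
+// f8_ocp_t    elem;                                           // one OCP fp8 element
+// float f = ck::scaled_type_convert<float>(scale, elem);
+// @endcode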
+ +#pragma once + +#include "ck/utility/type_convert.hpp" +#include "ck/utility/mxf8_utils.hpp" + +#ifdef CK_USE_NATIVE_MX_SUPPORT +#define CK_USE_NATIVE_MX_SUPPORT 1 +#else +#define CK_USE_NATIVE_MX_SUPPORT 0 +#endif + +namespace ck { + +// Declare a template function for scaled conversion +template +#if CK_USE_OCP_FP8 +__host__ __device__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x); +#else +__host__ constexpr Y scaled_type_convert(e8m0_bexp_t scale, X x); +#endif + +// convert f8_ocp_t to fp32 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t scale, f8_ocp_t x) +#else +inline __host__ float scaled_type_convert(e8m0_bexp_t scale, f8_ocp_t x) +#endif +{ + +#if CK_MX_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32_from_f8_scaled( + type_convert(scale), x.data); +#else + return type_convert(scale) * type_convert(x); +#endif +} + +// convert bf8_ocp_t to fp32 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t scale, + bf8_ocp_t x) +#else +inline __host__ float scaled_type_convert(e8m0_bexp_t scale, bf8_ocp_t x) +#endif +{ + +#if CK_MX_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32_from_f8_scaled( + type_convert(scale), x.data); +#else + return type_convert(scale) * type_convert(x); +#endif +} + +// convert 2 x f8_ocp_t to 2 x fp32 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float2_t scaled_type_convert(e8m0_bexp_t scale, + f8x2_ocp_t x) +#else +inline __host__ float2_t scaled_type_convert(e8m0_bexp_t scale, f8x2_ocp_t x) +#endif +{ +#if CK_MX_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + type_convert(scale), x.AsType()[Number<0>{}]); +#else + return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), + scaled_type_convert(scale, x.AsType()[Number<1>{}])}; +#endif +} + +// convert 2 x bf8_ocp_t to 2 x fp32 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float2_t scaled_type_convert(e8m0_bexp_t scale, + bf8x2_ocp_t x) +#else +inline __host__ float2_t scaled_type_convert(e8m0_bexp_t scale, + bf8x2_ocp_t x) +#endif +{ +#if CK_MX_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + type_convert(scale), x.AsType()[Number<0>{}]); +#else + return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), + scaled_type_convert(scale, x.AsType()[Number<1>{}])}; +#endif +} + +// convert 16 x f8_ocp_t to 16 x fp32 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float16_t scaled_type_convert(e8m0_bexp_t scale, + f8x16_ocp_t x) +#else +inline __host__ float16_t scaled_type_convert(e8m0_bexp_t scale, + f8x16_ocp_t x) +#endif +{ + union + { + f8x16_ocp_t f8_1x16; + f8x2_ocp_t f8_2x8[8]; + } in{x}; + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}([&](auto i) { + out.float_2x8[i] = scaled_type_convert(scale, in.f8_2x8[i]); + }); + + return out.float_1x16; +} + +// convert 16 x bf8_ocp_t to 16 x fp32 +// @note Host version gives compilation error. Requires extra compiler options. 
+template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float16_t scaled_type_convert(e8m0_bexp_t scale, + bf8x16_ocp_t x) +#else +inline __host__ float16_t scaled_type_convert(e8m0_bexp_t scale, + bf8x16_ocp_t x) +#endif +{ + union + { + bf8x16_ocp_t bf8_1x16; + bf8x2_ocp_t bf8_2x8[8]; + } in{x}; + union + { + float16_t float_1x16; + float2_t float_2x8[8]; + } out{}; + + ck::static_for<0, 8, 1>{}([&](auto i) { + out.float_2x8[i] = scaled_type_convert(scale, in.bf8_2x8[i]); + }); + + return out.float_1x16; +} + +// convert 32 x f8_ocp_t to 32 x fp32 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float32_t scaled_type_convert(e8m0_bexp_t scale, + f8x32_ocp_t x) +#else +inline __host__ float32_t scaled_type_convert(e8m0_bexp_t scale, + f8x32_ocp_t x) +#endif +{ + union + { + f8x32_ocp_t f8_1x32; + f8x16_ocp_t f8_16x2[2]; + } in{x}; + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}([&](auto i) { + out.float_16x2[i] = scaled_type_convert(scale, in.f8_16x2[i]); + }); + + return out.float_1x32; +} + +// convert 32 x bf8_ocp_t to 32 x fp32 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ float32_t scaled_type_convert(e8m0_bexp_t scale, + bf8x32_ocp_t x) +#else +inline __host__ float32_t scaled_type_convert(e8m0_bexp_t scale, + bf8x32_ocp_t x) +#endif +{ + union + { + bf8x32_ocp_t bf8_1x32; + bf8x16_ocp_t bf8_16x2[2]; + } in{x}; + union + { + float32_t float_1x32; + float16_t float_16x2[2]; + } out{}; + + ck::static_for<0, 2, 1>{}([&](auto i) { + out.float_16x2[i] = scaled_type_convert(scale, in.bf8_16x2[i]); + }); + + return out.float_1x32; +} + +// convert fp32 to fp8 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ f8_ocp_t scaled_type_convert(e8m0_bexp_t scale, float x) +#else +inline __host__ f8_ocp_t scaled_type_convert(e8m0_bexp_t scale, float x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32 to bf8 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ bf8_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float x) +#else +inline __host__ bf8_ocp_t scaled_type_convert(e8m0_bexp_t scale, float x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32x2 to fp8x2 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ f8x2_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float2_t x) +#else +inline __host__ f8x2_ocp_t scaled_type_convert(e8m0_bexp_t scale, float2_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} +// convert fp32x2 to bf8x2 +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ bf8x2_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float2_t x) +#else +inline __host__ bf8x2_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float2_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32x16 to fp8x16 +// @note Host version gives compilation error. Requires extra compiler options. 
+template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ f8x16_ocp_t +scaled_type_convert(e8m0_bexp_t scale, float16_t x) +#else +inline __host__ f8x16_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float16_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32x16 to bf8x16 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ bf8x16_ocp_t +scaled_type_convert(e8m0_bexp_t scale, float16_t x) +#else +inline __host__ bf8x16_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float16_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32x32 to fp8x32 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ f8x32_ocp_t +scaled_type_convert(e8m0_bexp_t scale, float32_t x) +#else +inline __host__ f8x32_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float32_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// convert fp32x32 to bf8x32 +// @note Host version gives compilation error. Requires extra compiler options. +template <> +#if CK_USE_OCP_FP8 +inline __host__ __device__ bf8x32_ocp_t +scaled_type_convert(e8m0_bexp_t scale, float32_t x) +#else +inline __host__ bf8x32_ocp_t scaled_type_convert(e8m0_bexp_t scale, + float32_t x) +#endif +{ +#if CK_USE_SR_F8_CONVERSION + return mxf8_convert_sr(x, type_convert(scale)); +#else + return mxf8_convert_rne(x, type_convert(scale)); +#endif +} + +// activate for architectures with native MX support +#if CK_USE_NATIVE_MX_SUPPORT +// convert fp4 to fp32 +template <> +inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t scale, f4_t x) +{ +#if defined(__gfx950__) + union + { + float float_array[2]; + float2_t float2_array; + } float_values{}; + float_values.float2_array = + __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, type_convert(scale), 0); + return float_values.float_array[0]; +#else + return utils::to_float(scale, x); +#endif +} + +// convert vector of 2 fp4 to vector of 2 fp32 +template <> +inline __host__ __device__ float2_t scaled_type_convert(e8m0_bexp_t scale, + f4x2_t x) +{ +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } value{}; + value.f4x2_array[0] = x; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); +#else + float2_t ret{utils::to_float( + scale, x.template AsType()[Number<0>{}].unpack<>(Number<1>{})), + utils::to_float( + scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{}))}; + return ret; +#endif +} + +// convert vector of 32 fp4 to vector of 32 fp32 +template <> +inline __host__ __device__ float32_t scaled_type_convert(e8m0_bexp_t scale, + f4x32_t x) +{ +#if defined(__gfx950__) + union + { + f4x32_t f4x32_array; + f4x2_t fp4x2[16]; + } value{x}; + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } bitwise_value{}; + float2_t op; + float32_t ret; + // TODO: pack in a loop + bitwise_value.f4x2_array[0] = value.fp4x2[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[0] = op[0]; + ret[1] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[1]; + op 
= __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
+        bitwise_value.bitwise, type_convert<float>(scale), 0);
+    ret[2] = op[0];
+    ret[3] = op[1];
+
+    // convert the remaining packed pairs in a loop
+    ck::static_for<2, 16, 1>{}([&](auto i) {
+        bitwise_value.f4x2_array[0] = value.fp4x2[i.value];
+        op                          = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
+            bitwise_value.bitwise, type_convert<float>(scale), 0);
+        ret[2 * i.value]     = op[0];
+        ret[2 * i.value + 1] = op[1];
+    });
+
+    return ret;
+#else
+    union
+    {
+        float32_t float32_array;
+        float float_array[32];
+    } float_values{};
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{bit_cast<__uint128_t>(x)};
+
+    // unpack each fp4 pair into two consecutive float outputs; the destination
+    // index must keep advancing across all 32 elements
+    ck::static_for<0, 16, 1>{}([&](auto i) {
+        float_values.float_array[2 * i.value] = utils::to_float(
+            scale,
+            f4_values.f4x2_array[i.value].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
+                Number<0>{}));
+        float_values.float_array[2 * i.value + 1] = utils::to_float(
+            scale,
+            f4_values.f4x2_array[i.value].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(
+                Number<1>{}));
+    });
+
+    return float_values.float32_array;
+#endif
+}
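+
+// Round-trip sketch (assumes a unit scale; the packed contents are
+// illustrative): widening 32 packed fp4 values costs 16 builtin calls on
+// gfx950, one per byte-packed pair.
+//
+// @code
+// f4x32_t   packed;  // 32 fp4 elements, two per byte
+// float32_t wide = ck::scaled_type_convert<float32_t>(
+//     NumericLimits<e8m0_bexp_t>::Binary_1(), packed);
+// @endcode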
+
+// convert fp32 to fp4
+template <>
+inline __host__ __device__ f4_t scaled_type_convert<f4_t>(e8m0_bexp_t scale, float x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x, type_convert<float>(scale));
+#else
+    return f4_convert_rne(x, type_convert<float>(scale));
+#endif
+}
+
+// convert vector of 2 fp32 to vector of 2 fp4
+template <>
+inline __host__ __device__ f4x2_t scaled_type_convert<f4x2_t>(e8m0_bexp_t scale,
+                                                              float2_t x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x, type_convert<float>(scale));
+#else
+    return f4_convert_rne(x, type_convert<float>(scale));
+#endif
+}
+
+// convert vector of 32 fp32 to vector of 32 fp4
+template <>
+inline __host__ __device__ f4x32_t scaled_type_convert<f4x32_t>(e8m0_bexp_t scale,
+                                                                float32_t x)
+{
+#if CK_USE_SR_F4_CONVERSION
+    return f4_convert_sr(x, type_convert<float>(scale));
+#else
+    return f4_convert_rne(x, type_convert<float>(scale));
+#endif
+}
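+
+// Rounding-mode note (a sketch): CK_USE_SR_F4_CONVERSION selects stochastic
+// rounding for every float-to-fp4 overload above; otherwise they round to
+// nearest even via f4_convert_rne.
+//
+// @code
+// float32_t   v;      // 32 fp32 values to quantize
+// e8m0_bexp_t scale;  // shared block scale
+// f4x32_t q = ck::scaled_type_convert<f4x32_t>(scale, v); // RNE unless SR enabled
+// @endcode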
+
+/**
+ * @brief Converts a 6-bit floating-point value (f6_t) to a 32-bit float,
+ * applying the specified scaling factor.
+ *
+ * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
+ * @param x The f6_t value to be converted.
+ * @return The converted 32-bit float representation of the input.
+ */
+template <>
+inline __host__ __device__ float scaled_type_convert<float>(e8m0_bexp_t scale, f6_t x)
+{
+#if defined(__gfx950__)
+    union
+    {
+        f6x32_t f6_vector;
+        f6_t f6_array[32];
+    } in{x};
+
+    union
+    {
+        float32_t float_vector;
+        float float_array[32];
+    } out{};
+
+    out.float_vector =
+        __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(in.f6_vector, type_convert<float>(scale));
+    return out.float_array[0];
+#else
+    return utils::to_float(scale, x);
+#endif
+}
+
+/**
+ * @brief Converts a vector of 32 6-bit floating-point values (f6x32_t) to a vector of 32 floats,
+ * applying the specified scaling factor.
+ *
+ * @param scale The exponent scale factor (e8m0_bexp_t).
+ * @param x The f6x32_t vector to be converted.
+ * @return The converted float vector representation of the input.
+ */
+template <>
+inline __host__ __device__ float32_t scaled_type_convert<float32_t>(e8m0_bexp_t scale,
+                                                                    f6x32_t x)
+{
+#if defined(__gfx950__)
+    return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(x, type_convert<float>(scale));
+#else
+    union
+    {
+        f6x32_t f6_vector;
+        f6_t f6_array[32];
+    } in{x};
+
+    union
+    {
+        float32_t float_vector;
+        float float_array[32];
+    } out{};
+
+    ck::static_for<0, 32, 1>{}(
+        [&](auto i) { out.float_array[i] = utils::to_float(scale, in.f6_array[i]); });
+
+    return out.float_vector;
+#endif
+}
+
+/**
+ * @brief Converts a 6-bit floating-point value (bf6_t) to a 32-bit float,
+ * applying the specified scaling factor.
+ *
+ * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
+ * @param x The bf6_t value to be converted.
+ * @return The converted 32-bit float representation of the input.
+ */ +template <> +inline __host__ __device__ float scaled_type_convert(e8m0_bexp_t scale, bf6_t x) +{ +#if defined(__gfx950__) + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + out.float_vector = + __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(in.bf6_vector, type_convert(scale)); + return out.float_array[0]; +#else + return utils::to_float(scale, x); +#endif +} + +/** + * @brief Converts a vector of 6-bit floating-point values (bf6x32_t) to a vector of 32 floats, + * applying the specified scaling factor. + * + * @param scale The exponent scale factor (e8m0_bexp_t). + * @param x The bf6x32_t vector to be converted. + * @return The converted vector of 32 float representation of the input. + */ +template <> +inline __host__ __device__ float32_t scaled_type_convert(e8m0_bexp_t scale, + bf6x32_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(x, type_convert(scale)); +#else + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}( + [&](auto i) { out.float_array[i] = utils::to_float(scale, in.bf6_array[i]); }); + + return out.float_vector; +#endif +} + +/** + * @brief Converts a 32-bit float to a 6-bit floating-point value (f6_t), applying the specified + * scale. + * + * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding + * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne). + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t. + * @param x The float value to convert. + * @return The converted 6-bit floating-point value (f6_t). + */ +template <> +inline __host__ __device__ f6_t scaled_type_convert(e8m0_bexp_t scale, float x) +{ +#if CK_USE_SR_F6_CONVERSION + return f6_convert_sr(x, type_convert(scale)); +#else + return f6_convert_rne(x, type_convert(scale)); +#endif +} + +/** + * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (f6x32_t), + * applying the specified scale. + * + * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding + * (f6_convert_sr) or round-to-nearest-even (f6_convert_rne). + * + * @param scale The exponent scale factor (e8m0_bexp_t). + * @param x The float vector to convert. + * @return The converted vector of 6-bit floating-point values (f6x32_t). + */ +template <> +inline __host__ __device__ f6x32_t scaled_type_convert(e8m0_bexp_t scale, + float32_t x) +{ +#if CK_USE_SR_F6_CONVERSION + return f6_convert_sr(x, type_convert(scale)); +#else + return f6_convert_rne(x, type_convert(scale)); +#endif +} + +/** + * @brief Converts a 32-bit float to a 6-bit floating-point value (bf6_t), applying the specified + * scale. + * + * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding + * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne). + * + * @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t. + * @param x The float value to convert. + * @return The converted 6-bit floating-point value (bf6_t). 
+ */ +template <> +inline __host__ __device__ bf6_t scaled_type_convert(e8m0_bexp_t scale, float x) +{ +#if CK_USE_SR_F6_CONVERSION + return bf6_convert_sr(x, type_convert(scale)); +#else + return bf6_convert_rne(x, type_convert(scale)); +#endif +} + +/** + * @brief Converts a vector of 32 floats to a vector of 32 6-bit floating-point values (bf6x32_t), + * applying the specified scale. + * + * Depending on whether CK_USE_SR_F6_CONVERSION is defined, it uses either stochastic rounding + * (bf6_convert_sr) or round-to-nearest-even (bf6_convert_rne). + * + * @param scale The exponent scale factor (e8m0_bexp_t). + * @param x The float vector to convert. + * @return The converted 6-bit floating-point vector (bf6x32_t). + */ +template <> +inline __host__ __device__ bf6x32_t scaled_type_convert(e8m0_bexp_t scale, + float32_t x) +{ +#if CK_USE_SR_F6_CONVERSION + return bf6_convert_sr(x, type_convert(scale)); +#else + return bf6_convert_rne(x, type_convert(scale)); +#endif +} +#endif // #if CK_USE_NATIVE_MX_SUPPORT + +} // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index a86de19645..e9b2e3fff2 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -5,6 +5,8 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/f8_utils.hpp" +#include "ck/utility/mxf4_utils.hpp" +#include "ck/utility/mxf6_utils.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/array.hpp" #include "ck/utility/amd_inline_asm.hpp" @@ -12,7 +14,7 @@ namespace ck { // Define the common macro for MI300 models -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__) #define __gfx94__ #endif @@ -707,6 +709,1278 @@ inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) #endif } +// convert fp32 to fp4 with rounding to nearest even +inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f) +{ +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4_t f4_array[4]; + } value{0}; + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0); + return value.f4_array[0]; +#else + return utils::sat_convert_to_type(x / scale); +#endif +} + +// convert vector of 2 fp32 to vector of 2 fp4 with rne +inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) +{ +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } value{0}; + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); + return value.f4x2_array[0]; +#else + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } value{0}; + uint8_t l = utils::sat_convert_to_type(x[1] / scale); + uint8_t h = utils::sat_convert_to_type(x[0] / scale); + value.bitwise = (h << 4) | l; + return value.f4x2_array[0]; +#endif +} + +// convert vector of 32 fp32 to vector of 32 fp4 with rne +inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f) +{ +#if defined(__gfx950__) + union + { + __uint128_t bitwise; + f4x2_t f4x2_array[16]; + f4x32_t f4x32_array; + } f4_values{}, tmp_values{}; + // TODO: pack in a loop + tmp_values.bitwise = + __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0); + f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = + __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0); + f4_values.f4x2_array[1] = 
tmp_values.f4x2_array[0];
+
+    // pack the remaining pairs in a loop
+    ck::static_for<2, 16, 1>{}([&](auto i) {
+        tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
+            tmp_values.bitwise, x[2 * i.value], x[2 * i.value + 1], scale, 0);
+        f4_values.f4x2_array[i.value] = tmp_values.f4x2_array[0];
+    });
+
+    return f4_values.f4x32_array;
+#else
+    union
+    {
+        __uint128_t bitwise;
+        f4x2_t f4x2_array[16];
+        f4x32_t f4x32_array;
+    } f4_values{};
+
+    // pack MSB-first: x[0] ends up in the most significant nibble
+    ck::static_for<0, 32, 1>{}([&](auto i) {
+        auto tmp = utils::sat_convert_to_type<f4_t>(x[i.value] / scale);
+        f4_values.bitwise <<= 4;
+        f4_values.bitwise |= tmp;
+    });
+
+    return f4_values.f4x32_array;
+#endif
+}
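+
+// Packing sketch (derived from the fallback above; values are illustrative):
+// the emulation path divides by the scale before saturating, mirroring the
+// hardware scalef32 behaviour.
+//
+// @code
+// float2_t in{1.0f, 2.0f};
+// f4x2_t   pair = f4_convert_rne(in, /*scale=*/1.0f);
+// @endcode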
f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[10] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[11] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[12] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[13] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[14] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[15] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + tmp = utils::sat_convert_to_type(x[16] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[17] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[18] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[19] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[20] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[21] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[22] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[23] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + tmp = utils::sat_convert_to_type(x[24] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[25] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[26] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[27] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[28] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[29] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[30] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type(x[31] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + return f4_values.f4x32_array; +#endif +} + +// convert fp32 to fp4 with stochastic rounding +inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4_t f4_array[4]; + } value{0}; + union + { + float float_array[2]; + float2_t float2_array; + } float_values{{x}}; + + value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + value.bitwise, float_values.float2_array, rng, scale, 0); + return value.f4_array[0]; +#else + return utils::sat_convert_to_type_sr(x / scale, rng); +#endif +} + +// convert vector of 2 fp32 to vector of 2 fp4 with sr +inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } value{0}; + value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); + return value.f4x2_array[0]; +#else + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } 
value{0}; + uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + value.bitwise = (h << 4) | l; + return value.f4x2_array[0]; +#endif +} + +// convert vector of 32 fp32 to vector of 32 fp4 with sr +inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#if defined(__gfx950__) + union + { + __uint128_t bitwise; + f4x2_t f4x2_array[16]; + f4x32_t f4x32_array; + } f4_values{0}, tmp_values{0}; + union + { + float2_t floatx2_array[16]; + float32_t floatx32_array; + } float_values{}; + float_values.floatx32_array = x; // feed the 32 input floats to the packed conversions below + // TODO: pack in a loop + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0); + f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0); + f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0); + f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0); + f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; + + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0); + f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0); + f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0); + f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0); + f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; + + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0); + f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0); + f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0); + f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0); + f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; + + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0); + f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0); + f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0); + 
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0); + f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; + + return f4_values.f4x32_array; +#else + union + { + __uint128_t bitwise; + f4x2_t f4x2_array[16]; + f4x32_t f4x32_array; + } f4_values{0}; + // TODO: pack in a loop + auto tmp = utils::sat_convert_to_type_sr(x[0] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[1] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[2] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[3] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[4] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[5] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[6] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[7] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + tmp = utils::sat_convert_to_type_sr(x[8] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[9] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[10] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[11] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[12] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[13] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[14] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[15] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + tmp = utils::sat_convert_to_type_sr(x[16] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[17] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[18] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[19] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[20] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[21] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[22] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[23] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + tmp = utils::sat_convert_to_type_sr(x[24] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[25] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[26] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[27] / scale, rng); + 
f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[28] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[29] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[30] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + tmp = utils::sat_convert_to_type_sr(x[31] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + + return f4_values.f4x32_array; +#endif +} + +// convert fp32 to fp4 +template <> +inline __host__ __device__ f4_t type_convert(float x) +{ +#if CK_USE_SR_F4_CONVERSION + return f4_convert_sr(x); +#else + return f4_convert_rne(x); +#endif +} + +// convert vector of 2 fp32 to vector of 2 fp4 +template <> +inline __host__ __device__ f4x2_t type_convert(float2_t x) +{ +#if CK_USE_SR_F4_CONVERSION + return f4_convert_sr(x); +#else + return f4_convert_rne(x); +#endif +} + +// convert vector of 32 fp32 to vector of 32 fp4 +template <> +inline __host__ __device__ f4x32_t type_convert(float32_t x) +{ +#if CK_USE_SR_F4_CONVERSION + return f4_convert_sr(x); +#else + return f4_convert_rne(x); +#endif +} + +// convert fp4 to fp32 +template <> +inline __host__ __device__ float type_convert(f4_t x) +{ +#if defined(__gfx950__) + union + { + float float_array[2]; + float2_t float2_array; + } float_values{}; + float scale = 1.0f; + float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0); + return float_values.float_array[0]; +#else + return utils::to_float(NumericLimits::Binary_1(), x); +#endif +} + +// convert vector of 2 fp4 to vector of 2 fp32 +template <> +inline __host__ __device__ float2_t type_convert(f4x2_t x) +{ +#if defined(__gfx950__) + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } value{}; + value.f4x2_array[0] = x; + float scale = 1.0f; + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); +#else + float2_t ret{ + utils::to_float(NumericLimits::Binary_1(), + x.template AsType()[Number<0>{}].unpack<>(Number<1>{})), + utils::to_float(NumericLimits::Binary_1(), + x.template AsType()[Number<0>{}].unpack<>(Number<0>{}))}; + return ret; +#endif +} + +// convert vector of 32 fp4 to vector of 32 fp32 +template <> +inline __host__ __device__ float32_t type_convert(f4x32_t x) +{ +#if defined(__gfx950__) + union + { + f4x32_t f4x32_array; + f4x2_t fp4x2[16]; + } value{x}; + union + { + uint32_t bitwise; + f4x2_t f4x2_array[4]; + } bitwise_value{}; + float2_t op; + float32_t ret; + float scale = 1.0f; + // TODO: pack in a loop + bitwise_value.f4x2_array[0] = value.fp4x2[0]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[0] = op[0]; + ret[1] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[1]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[2] = op[0]; + ret[3] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[2]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[4] = op[0]; + ret[5] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[3]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[6] = op[0]; + ret[7] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[4]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[8] = op[0]; + ret[9] = op[1]; + + 
bitwise_value.f4x2_array[0] = value.fp4x2[5]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[10] = op[0]; + ret[11] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[6]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[12] = op[0]; + ret[13] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[7]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[14] = op[0]; + ret[15] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[8]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[16] = op[0]; + ret[17] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[9]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[18] = op[0]; + ret[19] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[10]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[20] = op[0]; + ret[21] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[11]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[22] = op[0]; + ret[23] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[12]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[24] = op[0]; + ret[25] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[13]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[26] = op[0]; + ret[27] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[14]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[28] = op[0]; + ret[29] = op[1]; + + bitwise_value.f4x2_array[0] = value.fp4x2[15]; + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( + bitwise_value.bitwise, type_convert(scale), 0); + ret[30] = op[0]; + ret[31] = op[1]; + + return ret; +#else + union + { + float32_t float32_array; + float float_array[32]; + } float_values{}; + union + { + __uint128_t bitwise; + f4x2_t f4x2_array[16]; + f4x32_t f4x32_array; + } f4_values{bit_cast<__uint128_t>(x)}; + // TODO: pack in a loop + float_values.float_array[0] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[1] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[2] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[3] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[4] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[5] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[6] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[7] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[3].template 
AsType()[Number<0>{}].unpack<>(Number<1>{})); + + float_values.float_array[8] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[9] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[10] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[11] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[12] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[13] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[14] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[15] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + + float_values.float_array[16] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[17] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[18] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[19] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[20] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[21] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[22] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[23] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + + float_values.float_array[24] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[25] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[26] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[27] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[28] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[29] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[14].template 
AsType()[Number<0>{}].unpack<>(Number<1>{})); + float_values.float_array[30] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<0>{})); + float_values.float_array[31] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + + return float_values.float32_array; +#endif +} + +/** + * @brief Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even. + * + * Divides the input by the specified scale, then saturates and converts it + * to the 6-bit floating-point format (f6_t). + * + * @param x The input float value. + * @param scale A scaling factor applied to `x` before conversion. + * @return The converted f6_t value. + */ +inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f) +{ +#if defined(__gfx950__) + float16_t in1{x}; + float16_t in2{}; + + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } out{}; + + out.f6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale); + + return out.f6_array[0]; +#else + return utils::sat_convert_to_type(x / scale); +#endif +} + +/** + * @brief Converts a 32-element single-precision float array into a packed 6-bit representation. + * + * This function divides each input float by the provided scale value, then performs conversion with + * round-to-nearest-even to pack each element into 6 bits of precision. + * + * @param x A vector of 32 floats stored in float32_t. + * @param scale A scaling factor for each float before conversion. + * @return An f6x32_t object storing the compressed 6-bit representation. + */ +inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f) +{ +#if defined(__gfx950__) + float16_t* in1 = reinterpret_cast<float16_t*>(&x); + float16_t* in2 = in1 + 1; // second half of the input: floats 16..31 + return __builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale); +#else + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.f6_array[i] = utils::sat_convert_to_type(in.float_array[i] / scale); + }); + + return out.f6_vector; +#endif +} + +/** + * @brief Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding. + * + * Divides the input by the specified scale, then performs saturation and conversion + * to f6_t based on a pseudo-randomly generated seed. + * + * @param x The input float value. + * @param scale A scaling factor applied to `x` before conversion. + * @return The converted f6_t value. + */ +inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#if defined(__gfx950__) + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } out{}; + + out.f6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale); + + return out.f6_array[0]; +#else + return utils::sat_convert_to_type_sr(x / scale, rng); +#endif +} + +/** + * @brief Converts a 32-element single-precision float array into a packed 6-bit representation. + * + * This function divides each input float by the provided scale value, then performs conversion with + * stochastic rounding to pack each element into 6 bits of precision. + * + * @param x A vector of 32 floats stored in float32_t. 
+ * @param scale A scaling factor for each float before conversion. + * @return An f6x32_t object storing the compressed 6-bit representation. + */ +inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + union + { + float32_t float_vector; + float float_array[32]; + } float_values{x}; + uint32_t rng = + prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); +#else + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.f6_array[i] = utils::sat_convert_to_type_sr(in.float_array[i] / scale, rng); + }); + + return out.f6_vector; +#endif +} + +/** + * @brief Specializes the type conversion template for converting a float into the 6-bit float type + * (f6_t). + * + * Depending on the CK_USE_SR_F6_CONVERSION flag, + * the conversion uses stochastic rounding + * or round-to-nearest-even. + * + * @param x Input float value to be converted. + * @return The converted f6_t value. + */ +template <> +inline __host__ __device__ f6_t type_convert(float x) +{ +#if CK_USE_SR_F6_CONVERSION + return f6_convert_sr(x); +#else + return f6_convert_rne(x); +#endif +} + +/** + * @brief Specializes the type conversion template for converting a vector of 32 floats into the + * vector of 32 6-bit float types (f6x32_t). + * + * Depending on the CK_USE_SR_F6_CONVERSION flag, + * the conversion uses stochastic rounding + * or round-to-nearest-even. + * + * @param x Input float value to be converted. + * @return The converted f6x32_t vector. + */ +template <> +inline __host__ __device__ f6x32_t type_convert(float32_t x) +{ +#if CK_USE_SR_F6_CONVERSION + return f6_convert_sr(x); +#else + return f6_convert_rne(x); +#endif +} + +/** + * @brief Specializes the type conversion template for converting the 6-bit float type (f6_t) to + * float. + * + * Interprets an f6_t value as a float using the default scale factor of 1. + * + * @param x The 6-bit float (f6_t) value to be converted. + * @return The corresponding float representation. + */ +template <> +inline __host__ __device__ float type_convert(f6_t x) +{ +#if defined(__gfx950__) + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( + in.f6_vector, type_convert(NumericLimits::Binary_1())); + return out.float_array[0]; +#else + return utils::to_float(NumericLimits::Binary_1(), x); +#endif +} + +/** + * @brief Specializes the type conversion template for converting the vector of 32 6-bit float types + * (f6x32_t) to vector of 32 floats. + * + * Interprets an f6_t values as floats using the default scale factor of 1. + * + * @param x The vector of 32 6-bit float (f6x32_t) values to be converted. + * @return The corresponding float representation. 
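+ * + * @par Example + * A host-side usage sketch added for illustration (not part of the original patch; assumes a + * populated f6x32_t value named packed is in scope): + * @code + * float32_t decoded = ck::type_convert<float32_t>(packed); // decoded with a fixed scale of 1.0 + * @endcode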
 */ +template <> +inline __host__ __device__ float32_t type_convert(f6x32_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6( + x, type_convert(NumericLimits::Binary_1())); +#else + union + { + f6x32_t f6_vector; + f6_t f6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.float_array[i] = + utils::to_float(NumericLimits::Binary_1(), in.f6_array[i]); + }); + + return out.float_vector; +#endif +} + +/** + * @brief Converts a float to the 6-bit BF6 type using round-to-nearest-even. + * + * Divides the input by the specified scale, then saturates and converts + * it to a 6-bit BF6 floating-point format. + * + * @param x The float value to be converted. + * @param scale The scaling factor applied to the input before conversion. + * @return The converted bf6_t value. + */ +inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f) +{ +#if defined(__gfx950__) + float16_t in1{x}; + float16_t in2{}; + + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } out{}; + + out.bf6_vector = __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale); + + return out.bf6_array[0]; +#else + return utils::sat_convert_to_type(x / scale); +#endif +} + +/** + * @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using + * round-to-nearest-even. + * + * Divides the input by the specified scale, then saturates and converts + * it to a 6-bit BF6 floating-point format. + * + * @param x The float vector to be converted. + * @param scale The scaling factor applied to the input before conversion. + * @return The converted bf6x32_t vector. + */ +inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f) +{ +#if defined(__gfx950__) + float16_t* in1 = reinterpret_cast<float16_t*>(&x); + float16_t* in2 = in1 + 1; // second half of the input: floats 16..31 + return __builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale); +#else + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.bf6_array[i] = utils::sat_convert_to_type(in.float_array[i] / scale); + }); + + return out.bf6_vector; +#endif +} + +/** + * @brief Converts a float to the 6-bit BF6 type using stochastic rounding. + * + * Divides the input by the specified scale, + * and converts the result to a 6-bit BF6 floating-point + * format with stochastic rounding. + * + * @param x The float value to be converted. + * @param scale The scaling factor applied to the input before conversion. + * @return The converted bf6_t value. + */ +inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#if defined(__gfx950__) + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } out{}; + + out.bf6_vector = __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale); + + return out.bf6_array[0]; +#else + return utils::sat_convert_to_type_sr(x / scale, rng); +#endif +} + +/** + * @brief Converts a vector of 32 floats to a vector of 32 6-bit BF6 values using stochastic + * rounding. + * + * Divides the input by the specified scale, + * and converts the result to a 6-bit BF6 floating-point + * format with stochastic rounding. 
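+ * The rounding randomness is seeded from the address of the input and the value of its first + * element (via prand_generator in the function body below), so repeated calls on identical data + * at the same address yield identical results.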
+ * + * @param x The float vector to be converted. + * @param scale The scaling factor applied to the input before conversion. + * @return The converted bf6x32_t vector. + */ +inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f) +{ + constexpr int seed = 1254739; + union + { + float32_t float_vector; + float float_array[32]; + } float_values{x}; + uint32_t rng = + prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); +#else + union + { + float32_t float_vector; + float float_array[32]; + } in{x}; + + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.bf6_array[i] = utils::sat_convert_to_type_sr(in.float_array[i] / scale, rng); + }); + + return out.bf6_vector; +#endif +} + +/** + * @brief Specializes float-to-bf6_t conversion. + * + * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined, + * otherwise uses round-to-nearest-even. + * + * @param x Input float value to convert. + * @return Converted bf6_t value. + */ +template <> +inline __host__ __device__ bf6_t type_convert(float x) +{ +#if CK_USE_SR_F6_CONVERSION + return bf6_convert_sr(x); +#else + return bf6_convert_rne(x); +#endif +} + +/** + * @brief Specializes vector of 32 float-to-bf6_t conversion. + * + * Uses stochastic rounding if CK_USE_SR_F6_CONVERSION is defined, + * otherwise uses round-to-nearest-even. + * + * @param x Input float vector to convert. + * @return Converted bf6x32_t vector. + */ +template <> +inline __host__ __device__ bf6x32_t type_convert(float32_t x) +{ +#if CK_USE_SR_F6_CONVERSION + return bf6_convert_sr(x); +#else + return bf6_convert_rne(x); +#endif +} + +/** + * @brief Specializes the type conversion template for converting a bf6_t value to float. + * + * Interprets the bf6_t value using the default scale factor of 1 and returns + * its floating-point representation. + * + * @param x The bf6_t value to convert. + * @return The float representation of the given bf6_t value. + */ +template <> +inline __host__ __device__ float type_convert(bf6_t x) +{ +#if defined(__gfx950__) + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( + in.bf6_vector, type_convert(NumericLimits::Binary_1())); + return out.float_array[0]; +#else + return utils::to_float(NumericLimits::Binary_1(), x); +#endif +} + +/** + * @brief Specializes the type conversion template for converting a vector of 32 bf6_t values to + * vector of 32 floats. + * + * Interprets the bf6x32_t value using the default scale factor of 1 and returns + * its floating-point representation. + * + * @param x The bf6x32_t value to convert. + * @return The float representation of the given vector. 
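+ * + * @par Example + * An illustrative round trip added for clarity (editor's sketch, not part of the original patch): + * @code + * float32_t src = ...; // 32 source floats + * bf6x32_t packed = ck::type_convert<bf6x32_t>(src); // RNE unless CK_USE_SR_F6_CONVERSION is set + * float32_t back = ck::type_convert<float32_t>(packed); // decoded with an implicit scale of 1.0 + * @endcode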
+ */ +template <> +inline __host__ __device__ float32_t type_convert(bf6x32_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6( + x, type_convert(NumericLimits::Binary_1())); +#else + union + { + bf6x32_t bf6_vector; + bf6_t bf6_array[32]; + } in{x}; + + union + { + float32_t float_vector; + float float_array[32]; + } out{}; + + ck::static_for<0, 32, 1>{}([&](auto i) { + out.float_array[i] = + utils::to_float(NumericLimits::Binary_1(), in.bf6_array[i]); + }); + + return out.float_vector; +#endif +} + #ifndef CK_CODE_GEN_RTC template inline __host__ __device__ void array_convert(std::array& y, diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc index bf895f67c5..d78f266bfd 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc @@ -824,4 +824,4 @@ #undef _UK_PK_CVT_ #undef _UK_ATOMIC_ADD_ #undef CK_TILE_FLATMM_UK_MFMA -// clang-format on + // clang-format on diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc index f5e491c3c8..733afdbe94 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc @@ -722,4 +722,4 @@ #undef _UK_PK_CVT_ #undef _UK_ATOMIC_ADD_ #undef CK_TILE_FLATMM_UK_MFMA -// clang-format on + // clang-format on diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc index 7fa89d9d25..10531b7a26 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc @@ -771,4 +771,4 @@ #undef _UK_MFMA_ #undef CK_TILE_FLATMM_UK_2B #undef CK_TILE_FLATMM_UK_MFMA -// clang-format on + // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index bea22da2c2..fde946b67c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -41,13 +41,16 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| 
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> - // clang-format on +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on >; template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 
ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, @@ -72,6 +77,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -106,13 +112,16 @@ template using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> - // clang-format on +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on >; template using 
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, @@ -137,6 +148,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; 
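Editor's note on the guard pattern used in these instance lists: keeping the #if defined(CK_USE_AMD_MFMA_GFX950) branch empty collapses each tuple to std::tuple<> on gfx950 builds, so the instance factories that iterate over these lists still compile and link while registering no kernels whose MFMA tile shapes are not offered on that target. A minimal self-contained sketch of the same idiom follows; SomeXdlInstance is a hypothetical placeholder, not a CK type:

#include <tuple>

struct SomeXdlInstance
{
}; // hypothetical stand-in for a DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle instantiation

using guarded_instances = std::tuple<
#if defined(CK_USE_AMD_MFMA_GFX950)
// nothing is registered when the gfx950 MFMA path is active
#else
    SomeXdlInstance
#endif
    >;

static_assert(std::tuple_size_v<guarded_instances> <= 1,
              "empty tuple on gfx950, one instance elsewhere");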
@@ -171,13 +183,16 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> - // clang-format on +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on >; // NGCHW requires transpose, we use vector loads and stores params for them @@ -189,11 +204,13 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>, @@ -217,6 +234,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -229,13 +247,16 @@ template using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances = std::tuple< - // clang-format off +// clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | +#if 
defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>
-    // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+    // clang-format on
     >;

 template
 using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
 //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
 //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
 //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>,
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 2>,
@@ -274,6 +297,7 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>,
         DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
     // clang-format on
     >;
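The pattern above recurs throughout this patch: an instance tuple is wrapped in a host-side guard so that builds which define CK_USE_AMD_MFMA_GFX950 compile an empty tuple instead of the legacy-MFMA kernel list. A minimal sketch of the effect, assuming only that the macro comes from the build system; the kernel names below are placeholders, not CK types:

    #include <tuple>

    struct LegacyMfmaKernelA {}; // stands in for a tuned XDL instance
    struct LegacyMfmaKernelB {};

    using bwd_weight_instances =
    #if defined(CK_USE_AMD_MFMA_GFX950)
        std::tuple<>;                                     // gfx950 build: list compiles out
    #else
        std::tuple<LegacyMfmaKernelA, LegacyMfmaKernelB>; // other gfx9 targets keep the list
    #endif

An empty tuple keeps the factory code well-formed while registering zero instances, so callers degrade gracefully rather than failing to link.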
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
index dc4ee534b9..410abe366c 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp
@@ -56,11 +56,13 @@ template
 using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx950__)
+#else
 // Compute friendly
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -79,7 +81,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
-
+#endif // defined(__gfx950__)
 // clang-format on
 >;

@@ -90,11 +92,13 @@ template
 using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx950__)
+#else
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -109,6 +113,7 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(__gfx950__)
 // clang-format on
 >;

@@ -138,11 +143,13 @@ template
 using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx950__)
+#else
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
@@ -153,6 +160,7 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple<
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(__gfx950__)
 // clang-format on
 >;
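Unlike the CK_USE_AMD_MFMA_GFX950 guard, these headers key off __gfx950__, which (like the other __gfx9xx__ macros) is defined by the compiler only during the device compilation pass for that target. A single multi-target build can therefore drop these instances on gfx950 while keeping them for every other architecture in the target list. A sketch of the distinction, with a placeholder kernel type:

    #include <tuple>

    // Host-side macro: one decision for the whole build, made by the build system.
    // Device-side macro: one decision per device compilation pass / architecture.
    struct KernelV4 {}; // placeholder for a compute-friendly v4 instance

    using comp_instances =
    #if defined(__gfx950__)
        std::tuple<>;         // pruned only in the gfx950 device pass
    #else
        std::tuple<KernelV4>; // every other pass keeps the tuned list
    #endif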
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
index 242ad2f730..09a489cd04 100644
--- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp
@@ -40,15 +40,18 @@ template
 using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances =
     std::tuple<
-    // clang-format off
+// clang-format off
 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
 //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx950__)
+#else
 // Instances with NumGroupsPerBatch > 1
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
+#endif // defined(__gfx950__)
 // clang-format on
 >;

@@ -59,15 +62,18 @@ template
 using device_grouped_conv_fwd_xdl_merged_groups_f16_instances =
     std::tuple<
-    // clang-format off
+// clang-format off
 //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(__gfx950__)
+#else
 // Instances with NumGroupsPerBatch > 1
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
         DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
+#endif // defined(__gfx950__)
 // clang-format on
 >;
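For reference, the trailing NumGroupsPerBatch value in the merged-groups instances (8, 16, 32 above) is, on my reading of the NumGroups/ToMerge column, the number of convolution groups folded into one GEMM batch, trading per-group granularity for occupancy when individual groups are small. A toy calculation under that assumption:

    // Assumption: G convolution groups are processed NumGroupsPerBatch at a time;
    // the parameter semantics are inferred from the column headers, not CK docs.
    constexpr int G                 = 64; // hypothetical group count
    constexpr int NumGroupsPerBatch = 8;  // value from the first instance above
    constexpr int merged_batches    = G / NumGroupsPerBatch; // 8 merged batches
    static_assert(G % NumGroupsPerBatch == 0, "assumed: groups divide evenly");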
"_f8_") message("removing gemm_universal_streamk_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -109,7 +109,7 @@ function(add_instance_library INSTANCE_NAME) if(source MATCHES "_xdl") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(source MATCHES "_wmma") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) elseif(source MATCHES "mha") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() @@ -368,7 +368,7 @@ if(CK_DEVICE_CONV_INSTANCES) endif() if(CK_DEVICE_MHA_INSTANCES) set(gpu_list ${INST_TARGETS}) - if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a") + if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a" OR gpu_list MATCHES "gfx95") add_library(device_mha_operations ${CK_DEVICE_MHA_INSTANCES}) set_target_properties(device_mha_operations PROPERTIES diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index 5a9483b309..6f10205d3d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< - // clang-format off +// clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, 
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
index 5a9483b309..6f10205d3d 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
@@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
 using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
@@ -65,6 +68,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple<
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
 // clang-format on
 >;
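Note the shape of the gfx950 branch here and in the following batched-GEMM files: instead of compiling the list out, it supplies one retuned instance whose K1 is 16 rather than 8, so KPerBlock = K0PerBlock * K1 doubles from 32 to 64. I take this to reflect gfx950's deeper-K MFMA variants, though the diff itself gives no rationale; the arithmetic, at least, is checkable:

    // K elements consumed per block iteration by the 32x32 XDL instances.
    constexpr int K0PerBlock = 4;
    constexpr int K1_legacy  = 8;  // gfx90a/gfx94x branch
    constexpr int K1_gfx950  = 16; // gfx950 branch above
    static_assert(K0PerBlock * K1_legacy == 32, "legacy KPerBlock");
    static_assert(K0PerBlock * K1_gfx950 == 64, "gfx950 KPerBlock doubles");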
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
index 0fa0719237..6060cac0c8 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
@@ -27,12 +27,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
 using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
@@ -65,6 +68,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple<
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
 // clang-format on
 >;

diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
index dc7de8c689..8f9eb92a7f 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
@@ -26,23 +26,30 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#endif
 // clang-format on
 >;

 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
@@ -102,6 +109,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple<
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
 // clang-format on
 >;

diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
index cccad7ca19..545986f8b0 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
@@ -26,23 +26,30 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

 using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#endif
 // clang-format on
 >;

 // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
 using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline|
 //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | |
 //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | |
 //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+        DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>,
@@ -90,6 +97,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple<
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>,
         DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
 // clang-format on
 >;

diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
index cf23d01bf2..c971a55363 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
@@ -26,18 +26,22 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
+#endif

 // c[g, m, n] = a[g, m, k] * b[g, n, k]
 template
 using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
 //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
 //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
 //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
@@ -53,24 +57,28 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
 // Padded fallback kernel
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
-    // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+    // clang-format on
 >;

 template
 using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances = std::tuple<
-    // clang-format off
+// clang-format off
 //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskOut|
 //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Upper|
 //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Triangle|
 //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>,
         DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, Masking>
-    // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+    // clang-format on
 >;

 void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
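A secondary pattern appears in this file and the permute variants below: the file-scope GemmSpecialization/TensorSpecialization constants are wrapped in #if !defined(CK_USE_AMD_MFMA_GFX950). Since the only code that reads them is the instance list the gfx950 branch compiles out, leaving them in would, presumably (the diff gives no rationale), trip unused-variable diagnostics under -Werror. The shape of the fix:

    // Constants used only by the guarded instance tuple are guarded identically,
    // so a gfx950 build sees neither the tuple entries nor their now-dead inputs
    // (assumption: this is to silence -Wunused-const-variable, not a semantic change).
    #if !defined(CK_USE_AMD_MFMA_GFX950)
    static constexpr auto GemmDefault =
        ck::tensor_operation::device::GemmSpecialization::Default;
    #endif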
diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
index 498bf58fb3..6abb90a39b 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
@@ -26,10 +26,12 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
+#endif

 // c[g, m, n] = a[g, m, k] * b[g, n, k]
 template
 using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
-    // clang-format off
+// clang-format off
 // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias|
 // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar|
 // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector|
 // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
@@ -62,7 +66,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
-    // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+    // clang-format on
 >;

 void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(

diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
index 744bd6456d..608443cf44 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
@@ -26,10 +26,12 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
+#endif

 // c[g, m, n] = a[g, m, k] * b[g, n, k]
 template
 using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
-    // clang-format off
+// clang-format off
 // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| D0s Bias|
 // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | SrcScalar|
 // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | PerVector|
 // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
@@ -64,7 +68,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
-    // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+    // clang-format on
 >;

 void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(

diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
index b342612d1c..07eb744323 100644
--- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
@@ -26,10 +26,12 @@ using S = ck::Sequence;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
+#endif

 // c[g, m, n] = a[g, m, k] * b[g, n, k]
 template
 using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
-    // clang-format off
+// clang-format off
 // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
 // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
 // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
 // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
         DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -60,7 +64,8 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on >; void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 3fd0c07370..09055ea188 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -26,10 +26,12 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Scale = 
ck::tensor_operation::element_wise::Scale; +#if !defined(CK_USE_AMD_MFMA_GFX950) static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; +#endif // c[g, m, n] = a[g, m, k] * b[g, n, k] template using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< - // clang-format off +// clang-format off // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -62,7 +66,8 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> - // clang-format on +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on >; void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index da96c79a6e..d30b93cf6a 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -40,11 +40,12 @@ static constexpr auto ConvFwdOddC = // arbitrary conv using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -58,16 +59,20 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple< DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; // 1x1, pad 0 using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -81,16 +86,20 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std: DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; // 1x1, stride 1, pad 0 using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -104,15 +113,19 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = s DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, 
PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -131,6 +144,9 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std:: DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, 
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp index e34ea06ff4..8d57edf4d7 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -39,11 +39,12 @@ static constexpr auto ConvFwdOddC = // arbitrary conv using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, 
ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -57,16 +58,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = s DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; // 1x1, pad 0 using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| 
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -80,16 +85,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instan DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; // 1x1, stride 1, pad 0 using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -103,16 +112,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_ins DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + 
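+// gfx950: the K1 = 8 instances above are compiled out; only the single 256-thread, 128x128-tile, K1 = 16 instance below is built for this specialization.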
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; // Odd C using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple< - // clang-format off +// clang-format off //##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, @@ -131,6 +144,9 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instanc DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 
1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> +#else + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> +#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp index 3254fcfc26..63ca21ff74 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -37,11 +37,12 @@ static constexpr auto ConvFwdOddC = // arbitrary conv using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple< - // clang-format off +// clang-format off //##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !defined(CK_USE_AMD_MFMA_GFX950) 
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
@@ -55,16 +56,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
+#else
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
+#endif
    // clang-format on
    >;

// 1x1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
-    // clang-format off
+// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if !defined(CK_USE_AMD_MFMA_GFX950)
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
@@ -78,16 +83,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_in
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
+#else
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
+#endif
    // clang-format on
    >;

// 1x1, stride 1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
-    // clang-format off
+// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if !defined(CK_USE_AMD_MFMA_GFX950)
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
@@ -101,16 +110,20 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
+#else
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
+#endif
    // clang-format on
    >;

// Odd C
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple<
-    // clang-format off
+// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if !defined(CK_USE_AMD_MFMA_GFX950)
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
@@ -129,6 +142,9 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_ins
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
+#else
+    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
+#endif
    // clang-format on
    >;
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
index 2e884dfc8a..2189a10e48 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -29,12 +29,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | Version|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -82,6 +85,7 @@ using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tu
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>,
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // !defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
index 60d4ccf525..ce58469697 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
@@ -55,6 +55,11 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    ,
+    DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+    DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 64, 128, 32, 32, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1>
+#endif
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    // pipeline v1, 2 waves
    ,
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
index 44b6848233..73fcc576a1 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
@@ -31,12 +31,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -93,6 +96,7 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
    DeviceGemm_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // !defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
index 23176269c2..be5751e065 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
@@ -31,12 +31,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -93,6 +96,7 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
    DeviceGemm_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // !defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
index 8c9a96f6b7..c9221d194a 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
@@ -35,22 +35,27 @@ static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecializati
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;

// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if !defined(CK_USE_AMD_MFMA_GFX950)
// pipeline v1, 1 wave
    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -109,6 +114,7 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
    DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // !defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;
diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
index b591dacff5..beef06c9f4 100644
--- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -34,23 +34,30 @@ static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+    //DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
+    //DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 128, 32, 32, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1>
+    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
    // clang-format on
    >;

template <GemmSpecialization GemmSpec>
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
-    // clang-format off
+// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if !defined(CK_USE_AMD_MFMA_GFX950)
// pipeline v1, 1 wave
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -100,6 +107,7 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
    DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // !defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -110,7 +118,6 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( { add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances{}); - add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 0ade7a61ce..3fdf6e23ce 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -47,6 +47,11 @@ using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> +#if defined(CK_USE_AMD_MFMA_GFX950) + , + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 256, 64, 64, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> +#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp index a64424e8ac..84e6410bfb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation 
parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -18,6 +17,8 @@ using Instances = //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -26,9 +27,10 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp index 0a0406baec..f9c9a4b876 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,7 +17,10 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp index 3671bea7a3..1a6bfef326 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -18,6 +17,8 @@ using Instances = //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 
GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -26,9 +27,10 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp index 95fc8ecb46..2aaee9ab26 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,7 +17,10 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp index fa53a3bf0f..d14137b56c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -18,6 +17,8 @@ using Instances = //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -35,9 +36,10 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp index c9d1913aec..a0b0ba5018 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,7 +17,10 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp index 0410eabb70..4437150047 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp @@ -17,6 +17,8 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, @@ -30,6 +32,7 @@ using Instances = std::tuple< DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp index a41919aab7..694b06f0b3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -17,9 +17,12 @@ using Instances = std::tuple< //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) #endif // clang-format on >; diff --git 
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
index 125fbc21a2..77bdb78401 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
@@ -38,23 +38,30 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[k, m], b[k, n], d0[m, n], d1[m, n]
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -111,18 +118,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 // irregular tile size
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
 // pipeline v1, 2 waves
@@ -134,7 +145,8 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn
 , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
index cc33692d73..1c93239af6 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
@@ -38,23 +38,30 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[k, m], b[n, k], d0[m, n], d1[m, n]
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -111,18 +118,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 // irregular tile size
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
 // pipeline v1, 2 waves
@@ -134,7 +145,8 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn
 , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
index 704787a080..162b5a0838 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
@@ -38,23 +38,30 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[m, k], b[k, n], d0[m, n], d1[m, n]
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -111,18 +118,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 // irregular tile size
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
 // pipeline v1, 2 waves
@@ -134,7 +145,8 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn
 , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
index d64c9ec5e1..7e3373e370 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
@@ -38,23 +38,30 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[m, k], b[n, k], d0[m, n], d1[m ,n]
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#endif
 // clang-format on
 >;

 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -102,18 +109,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>,
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 // irregular tile size
 using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
 // pipeline v1, 2 waves
@@ -125,7 +136,8 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn
 , DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
index e68bd8e7e4..245aa28fc6 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
@@ -23,22 +23,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[k, m], b[k, n], d0[m, n]
 using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -95,18 +102,22 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instanc
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
 #endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
 // clang-format on
 >;

 // irregular tile size
 using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
 //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
 //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
 //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
 //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
 // pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
 DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
 #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
 // pipeline v1, 2 waves
@@ -118,7 +129,8 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_irregul
 , DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
 #endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
 >;

 void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
index 5aaa2e8fe5..5ac4a8c10d 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
@@ -23,22 +23,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
 // input: a[k, m], b[n, k], d0[m, n], d1[m, n]
 using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -95,18 +102,22 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instanc
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -118,7 +129,8 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_irregul
, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
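Every gemm_add_fastgelu instance file in this patch repeats the same selection pattern: when CK_USE_AMD_MFMA_GFX950 is defined (the build is expected to define it for gfx950 targets, as the guards here assume), the generic instance switches from KPerBlock 32 with AK1/BK1 8 to a 16x16-MFMA-friendly shape with KPerBlock 64 and AK1/BK1 16, and the long tuned-instance list collapses to a single 256-thread 128x128 tile. A minimal standalone sketch of that guard follows; the TileShape struct and the main harness are illustrative scaffolding, not part of the patch.

#include <cstdio>

// Placeholder for the three tile parameters the patch actually switches.
struct TileShape
{
    int KPerBlock;
    int AK1;
    int BK1;
};

#if defined(CK_USE_AMD_MFMA_GFX950)
// gfx950 path: deeper K tile, 16-wide K1 slices to feed the 16x16 MFMA shape
static constexpr TileShape kGenericTile{64, 16, 16};
#else
// pre-gfx950 path, kept unchanged in the #else branch of each file
static constexpr TileShape kGenericTile{32, 8, 8};
#endif

int main()
{
    std::printf("KPerBlock=%d AK1=%d BK1=%d\n",
                kGenericTile.KPerBlock, kGenericTile.AK1, kGenericTile.BK1);
    return 0;
}

Compiling this sketch with and without -DCK_USE_AMD_MFMA_GFX950 reproduces the two configurations the diff encodes.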
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
index 7a2a3dbaf3..22c23c1bfe 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
@@ -23,22 +23,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -95,18 +102,22 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instanc
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -118,7 +129,8 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_irregul
, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
index fa33609978..d4849138d0 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
@@ -23,22 +23,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// input: a[m, k], b[n, k], d0[m, n], d1[m ,n]
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -86,18 +93,22 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instanc
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -109,7 +120,8 @@ using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_irregul
, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
index 28a452c1a1..1707c34c0e 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
@@ -28,8 +28,10 @@ using S = ck::Sequence;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+#endif

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
@@ -37,11 +39,13 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
@@ -58,17 +62,20 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instan
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
@@ -81,7 +88,8 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregu
, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
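The four gemm_add_relu_add_layernorm files take a different approach from the fastgelu ones: the gfx950 branch of each tuple is left intentionally empty, so no layernorm instances are registered for that target in this patch, and the then-unused GemmDefault and GemmMNKPadding constants are compiled out with #if !defined(CK_USE_AMD_MFMA_GFX950), presumably to avoid unused-variable warnings. A minimal sketch of how an empty #if branch collapses a tuple follows; InstanceA and InstanceB are stand-ins, not CK types.

#include <tuple>

struct InstanceA {}; // stand-in for a DeviceGemmMultipleDLayernorm_Xdl_CShuffle configuration
struct InstanceB {};

using LayernormInstances = std::tuple<
#if defined(CK_USE_AMD_MFMA_GFX950)
// intentionally empty: no layernorm instances for gfx950 in this patch
#else
    InstanceA,
    InstanceB
#endif
    >;

// The tuple size tracks the guard: 0 on gfx950, 2 otherwise.
#if defined(CK_USE_AMD_MFMA_GFX950)
static_assert(std::tuple_size_v<LayernormInstances> == 0);
#else
static_assert(std::tuple_size_v<LayernormInstances> == 2);
#endif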
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
index 13366238d6..ed10df8a65 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
@@ -28,20 +28,23 @@ using S = ck::Sequence;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-
+#endif
// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// outout: h[m, n]
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
@@ -58,17 +61,20 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instan
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
@@ -81,7 +87,8 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregu
, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
index 8a4889ee83..5b57e3b249 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
@@ -28,20 +28,23 @@ using S = ck::Sequence;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-
+#endif
// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// outout: h[m, n]
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
@@ -58,17 +61,20 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instan
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
@@ -81,7 +87,8 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregu
, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
index fc3cbcf905..fe372e5513 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
@@ -28,8 +28,10 @@ using S = ck::Sequence;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;

+#if !defined(CK_USE_AMD_MFMA_GFX950)
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+#endif

// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
@@ -37,11 +39,13 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
@@ -55,17 +59,20 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instan
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
@@ -78,7 +85,8 @@ using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregu
, DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
- // clang-format on
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
+ // clang-format on
>;

void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
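One structural detail in the irregular-tile tuples above is worth calling out: the new +#endif // defined(CK_USE_AMD_MFMA_GFX950) closes an outer guard whose #else branch itself nests the pre-existing CK_EXPERIMENTAL_INTER_WAVE_INSTANCES block, which is why the trailing comment on the #endif earns its keep. A condensed sketch of that nesting, using placeholder types rather than the real instances:

#include <tuple>

struct Gfx950Generic {};   // stand-in for the KPerBlock-64, AK1/BK1-16 padded instance
struct LegacyGenericV1 {}; // stand-in for the PipelineVersion::v1 padded instance
struct LegacyGenericV2 {}; // stand-in for the PipelineVersion::v2 inter-wave instance

using IrregularTileInstances = std::tuple<
#if defined(CK_USE_AMD_MFMA_GFX950)
    Gfx950Generic // single instance; the experimental block below is skipped entirely
#else
    LegacyGenericV1
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
    ,
    LegacyGenericV2 // only added when the experimental inter-wave flag is set
#endif
#endif // defined(CK_USE_AMD_MFMA_GFX950)
    >;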
diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
index 803c44c7f5..16c64cd276 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
@@ -22,22 +22,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// outout: e[m, n]
// input: a[k, m], b[k, n]
using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
@@ -94,17 +101,21 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::t
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

// irregular tile size
using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline|
//##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
+#if defined(CK_USE_AMD_MFMA_GFX950)
+ DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+#else
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
@@ -116,6 +127,7 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_irregular_tile_ins
, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
index 9b9ef3db24..1bdfebe977 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
@@ -22,22 +22,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial
// outout: e[m, n]
// input: a[k, m], b[k, n]
using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance = std::tuple<
- // clang-format off
+// clang-format off
//##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer|
NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | 
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -94,17 +101,21 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::t DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; // irregular tile size using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -116,6 +127,7 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_irregular_tile_ins , DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 1a0b6c9d1c..86f7eac1e2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -22,22 +22,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // outout: e[m, n] // input: a[m, k], b[k, n] using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256,
256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -94,17 +101,21 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::t DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; // irregular tile size using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, 
PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -116,6 +127,7 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_irregular_tile_ins , DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 18b1c0e993..891fafb3ce 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -22,22 +22,29 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // outout: e[m, n] // input: a[m, k], b[n, k] using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< 
Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1,
128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, @@ -85,17 +92,21 @@ using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::t DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; // irregular tile size using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off +// clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // pipeline v1, 1 wave +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#else DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -107,6 +118,7 @@ using 
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_irregular_tile_ins , DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> #endif +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp index efc7a7ebfd..ae1b7293f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp @@ -33,11 +33,13 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances = std::tuple< - // clang-format off +// clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else //PipelineVersion::v1; interwave DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, @@ -57,6 +59,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances = std::tuple< DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, 
Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp index 6a323d323f..04d16d8a06 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp @@ -34,11 +34,13 @@ template using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< - // clang-format off +// clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 
1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, @@ -63,6 +65,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tup DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp index 2855235f97..e7697d6b07 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp @@ -33,12 +33,14 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances = std::tuple< - // clang-format off +// clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //PipelineVersion::v1; interwave +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, @@ -52,6 +54,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances = std::tuple< DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp index 3300c4b0f7..01c4a06dcb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -36,12 +36,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -54,17 +55,19 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; template 
using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) +#else // Latency friendly DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -83,6 +86,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; } // namespace instance
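In the universal BF16 headers the gfx950 branch is left empty rather than retuned, so on that target the comp and mem instance lists collapse to an empty std::tuple<> until gfx950-specific configurations are added. A short sketch of why an empty branch stays well-formed end to end (CompV4, CompV5, and register_all are hypothetical stand-ins, not the library's factory API):

#include <cstdio>
#include <tuple>

// Hypothetical instance types; the real lists hold DeviceGemm_Xdl_CShuffleV3
// instantiations with full tile and pipeline parameter sets.
struct CompV4 { static constexpr const char* name = "intrawave_v4"; };
struct CompV5 { static constexpr const char* name = "intrawave_v5"; };

#if defined(CK_USE_AMD_MFMA_GFX950)
// Empty on gfx950, mirroring the hunks above: std::tuple<> is still a
// well-formed list, so downstream registration compiles to a no-op.
using comp_instances = std::tuple<>;
#else
using comp_instances = std::tuple<CompV4, CompV5>;
#endif // defined(CK_USE_AMD_MFMA_GFX950)

// Fold over the pack; it expands to zero calls for the empty tuple.
template <typename... Ts>
void register_all(std::tuple<Ts...>)
{
    (std::printf("adding %s\n", Ts::name), ...);
}

int main() { register_all(comp_instances{}); }

diff --git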
index d7b0051186..aa058b2d69 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp
@@ -36,12 +36,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -59,17 +60,19 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -89,6 +92,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tup
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 8, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 8, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
index 9566d5555a..1db09e13cb 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
@@ -36,12 +36,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -52,17 +53,19 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -81,6 +84,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tup
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
index 72162b65d3..b015da393f 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
@@ -36,12 +36,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -56,17 +57,19 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tu
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -85,6 +88,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tup
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
index af9494f5a7..a145b938dd 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp
@@ -34,12 +34,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -54,17 +55,19 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -95,6 +98,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
index f9d693f456..be1dde69db 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp
@@ -34,12 +34,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
@@ -66,17 +67,19 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -111,6 +114,7 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
index f2eb52b49a..1cef1f49f2 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp
@@ -35,12 +35,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// Disable due to test failure
@@ -50,17 +51,19 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -73,6 +76,7 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
index 78d16670cb..26e630eb5d 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp
@@ -35,29 +35,32 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Compute friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched>
using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -78,6 +81,7 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 16, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
index 83a898df61..7dc1b701fd 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp
@@ -39,12 +39,13 @@ template <GemmSpecialization GemmSpec, typename DsLayout = ck::Tuple<>, typename DsDataType = ck::Tuple<>>
using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
@@ -55,6 +56,7 @@ using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances =
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;

@@ -63,12 +65,13 @@ template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched, typename DsLayout = ck::Tuple<>, typename DsDataType = ck::Tuple<>>
using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-
+#if defined(CK_USE_AMD_MFMA_GFX950)
+#else
// Latency friendly
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -91,6 +94,7 @@ using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_instances = s
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
+#endif // defined(CK_USE_AMD_MFMA_GFX950)
// clang-format on
>;
} // namespace instance
diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp
index cfdf5fb67c..617d5f49a5 100644
--- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp
+++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn.hpp
@@ -40,12 +40,15 @@ template <GemmSpecialization GemmSpec, typename DsLayout = ck::Tuple<>, typename DsDataType = ck::Tuple<>>
using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
- // clang-format off
+// clang-format off
//#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) + //DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v2> +#else DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -54,6 +57,7 @@ using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, 
DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -62,12 +66,14 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) + //DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#else // Latency friendly DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -80,6 +86,7 @@ using device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_instances = std DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, 
S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index a81b5c419e..7ef8c01729 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -39,12 +39,13 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -55,6 +56,7 @@ using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -63,12 +65,13 @@ template , typename DsDataType = ck::Tuple<>> using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| DsLayout| CLayout|AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| 
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) +#else // Latency friendly DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -91,6 +94,7 @@ using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_instances = std: DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp index 5460f7f857..aa193417d9 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -34,12 +34,14 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| 
Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(__gfx950__) + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v2> +#else DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -52,20 +54,23 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, 
BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#endif // defined(__gfx950__) // clang-format on >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - - // Latency friendly +#if defined(__gfx950__) +#else + // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 
64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -95,6 +100,7 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#endif // defined(__gfx950__) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp index e716b3e85c..938eea8ef5 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -34,12 +34,13 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - +#if defined(CK_USE_AMD_MFMA_GFX950) +#else // Compute friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -64,18 +65,20 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< - // clang-format off +// clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - - // Latency friendly +#if defined(CK_USE_AMD_MFMA_GFX950) +#else + // Latency friendly DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -109,6 +112,7 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 2, 2, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp index f1c1cfc391..eeaf6c048d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp @@ -27,15 +27,18 @@ using S = ck::Sequence; using Empty_Tuple = ck::Tuple<>; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#if !defined(CK_USE_AMD_MFMA_GFX950) static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - +#endif using device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_tile_instances = std::tuple< - // clang-format off +// clang-format 
off //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) +#else DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, @@ -98,6 +101,7 @@ using device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_tile_instance DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F8, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 
3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp index d943376a37..e7c02805af 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -44,11 +44,14 @@ template using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +// clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#else // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 
32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -57,6 +60,7 @@ using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = s // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -65,13 +69,14 @@ template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< +// clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, 
DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -84,8 +89,9 @@ using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // clang-format on - >; +#endif // defined(CK_USE_AMD_MFMA_GFX950) + // clang-format on + >; } // namespace instance } // namespace device diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 29a216c704..5de59ee5a3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -100,7 +100,7 @@ function(add_test_executable TEST_NAME) if(ARGN MATCHES "_xdl") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() @@ -169,26 +169,38 @@ function(add_gtest_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + foreach(source IN LISTS ARGN) if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + + foreach(source IN LISTS ARGN) + if(NOT TEST_TARGETS MATCHES "gfx95" AND source MATCHES "mx_") + message("removing microscaling test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) if(NOT TEST_TARGETS MATCHES "gfx11" AND NOT TEST_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- 
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + elseif(ARGN MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -258,8 +270,11 @@ add_subdirectory(wrapper) if(SUPPORTED_GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2 +if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") # smfmac needs ROCm6.2 add_subdirectory(smfmac_op) endif() +if(SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_subdirectory(mx_mfma_op) +endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 3b1dfecb48..58d8768736 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -9,11 +9,10 @@ if (USE_BITINT_EXTENSION_INT4) endif() endif() - - add_custom_target(test_fp8) if (CK_USE_OCP_FP8) + # add test for ocp data types add_gtest_executable(test_fp8_ocp test_fp8_ocp.cpp) if(result EQUAL 0) target_link_libraries(test_fp8_ocp PRIVATE utility) @@ -43,6 +42,45 @@ if (CK_USE_FNUZ_FP8) add_dependencies(test_fp8 test_bf8_fnuz) endif() +if(GPU_TARGETS MATCHES "gfx950") + add_custom_target(test_mx_data_types) + + add_gtest_executable(test_fp4 test_fp4.cpp) + if(result EQUAL 0) + target_link_libraries(test_fp4 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_fp4) + + add_gtest_executable(test_fp6 test_fp6.cpp) + if(result EQUAL 0) + target_link_libraries(test_fp6 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_fp6) + + add_gtest_executable(test_bf6 test_bf6.cpp) + if(result EQUAL 0) + target_link_libraries(test_bf6 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_bf6) + + add_gtest_executable(test_mx_fp8 test_mx_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_mx_fp8 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_mx_fp8) + + add_gtest_executable(test_mx_bf8 test_mx_bf8.cpp) + if(result EQUAL 0) + target_link_libraries(test_mx_bf8 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_mx_bf8) + + add_gtest_executable(test_e8m0 test_e8m0.cpp) + if(result EQUAL 0) + target_link_libraries(test_e8m0 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_e8m0) +endif() add_gtest_executable(test_custom_type test_custom_type.cpp) if(result EQUAL 0) target_link_libraries(test_custom_type PRIVATE utility) diff --git a/test/data_type/test_bf6.cpp 
b/test/data_type/test_bf6.cpp new file mode 100644 index 0000000000..a260f81d16 --- /dev/null +++ b/test/data_type/test_bf6.cpp @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/scaled_type_convert.hpp" + +using ck::bf6_convert_rne; +using ck::bf6_convert_sr; +using ck::bf6_t; +using ck::bf6x16_pk_t; +using ck::bf6x32_pk_t; +using ck::e8m0_bexp_t; +using ck::Number; +using ck::scaled_type_convert; +using ck::type_convert; +using ck::vector_type; + +TEST(BF6, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), bf6_t(0b001000)); + EXPECT_EQ(ck::NumericLimits::Max(), bf6_t(0b011111)); + EXPECT_EQ(ck::NumericLimits::Lowest(), bf6_t(0b111111)); + EXPECT_EQ(ck::NumericLimits::MinSubnorm(), bf6_t(0b000001)); + EXPECT_EQ(ck::NumericLimits::MaxSubnorm(), bf6_t(0b000011)); +} + +TEST(BF6, ConvertFP32Nearest) +{ + // set maximum bf6 value + float max_bf6 = 28.0f; + // convert 0 float to bf6 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(bf6_convert_rne(0.0f)), 0.0f); + // convert max_bf6 to float and check if equal to max_bf6 + ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(max_bf6)), 0.0f); + // convert maximal float to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR( + max_bf6, type_convert(bf6_convert_rne(std::numeric_limits::max())), 0.0f); + // convert float Inf to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR(max_bf6, + type_convert(bf6_convert_rne(std::numeric_limits::infinity())), + 0.0f); + // convert float value less than bf6 subnorm to bf6 and back, check if equal to 0.0 + float less_than_subnorm = 0.03125f; + ASSERT_NEAR(0.0f, type_convert(bf6_convert_rne(less_than_subnorm)), 0.0f); + // convert float NaN to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR(max_bf6, + type_convert(bf6_convert_rne(std::numeric_limits::quiet_NaN())), + 0.0f); + // positive norm float value to bf6 and back, check if holds + float pos_float = 0.25f; + ASSERT_NEAR(pos_float, type_convert(bf6_convert_rne(pos_float)), 0.0f); + // negative norm float value to bf6 and back, check if holds + float neg_float = -0.5f; + ASSERT_NEAR(neg_float, type_convert(bf6_convert_rne(neg_float)), 0.0f); + // positive subnorm float value to bf6 and back, check if holds + pos_float = 0.1875f; + ASSERT_NEAR(pos_float, type_convert(bf6_convert_rne(pos_float)), 0.0f); + // negative subnorm float value to bf6 and back, check if holds + neg_float = -0.0625f; + ASSERT_NEAR(neg_float, type_convert(bf6_convert_rne(neg_float)), 0.0f); +} + +TEST(BF6, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum bf6 value + float max_bf6 = 28.0f; + // convert 0 float to bf6 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(bf6_convert_sr(0.0f)), abs_tol); + // convert maximal bf6_t to float and check if equal to max_bf6 + ASSERT_NEAR(max_bf6, type_convert(bf6_convert_sr(max_bf6)), abs_tol); + // convert maximal float to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR( + max_bf6, type_convert(bf6_convert_sr(std::numeric_limits::max())), abs_tol); + // convert float Inf to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR(max_bf6, + type_convert(bf6_convert_rne(std::numeric_limits::infinity())), + 0.0f); + // convert float NaN to bf6 and back, check if clipped to max_bf6 + ASSERT_NEAR(max_bf6, + 
type_convert(bf6_convert_rne(std::numeric_limits::quiet_NaN())), + 0.0f); + // positive norm float value to bf6 and back, check if holds + float pos_float = 0.25f; + ASSERT_NEAR(pos_float, type_convert(bf6_convert_sr(pos_float)), abs_tol); + // negative norm float value to bf6 and back, check if holds + float neg_float = -0.5f; + ASSERT_NEAR(neg_float, type_convert(bf6_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to bf6 and back, check if holds + pos_float = 0.1875f; + ASSERT_NEAR(pos_float, type_convert(bf6_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to bf6 and back, check if holds + neg_float = -0.0625f; + ASSERT_NEAR(neg_float, type_convert(bf6_convert_sr(neg_float)), abs_tol); +} + +TEST(BF6, ScaledConvertFP32Nearest) +{ + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to bf6 and back with maximal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_rne(0.0f)), 0.0f); + // convert 0 float to bf6 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_rne(0.0f)), 0.0f); + // positive norm float value to bf6 and back with various scales, check if holds + float pos_float = 0.25f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_rne(pos_float)), + 0.0f); + // negative norm float value to bf6 and back with various scales, check if holds + float neg_float = -0.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_rne(neg_float)), + 0.0f); + // positive subnorm float value to bf6 and back with various scales, check if holds + pos_float = 0.1875f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_rne(pos_float)), + 0.0f); + // negative subnorm float value to bf6 and back with various scales, check if holds + neg_float = -0.0625f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_rne(neg_float)), + 0.0f); +} + +TEST(BF6, ScaledConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = 
type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to bf6 and back with maximal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_sr(0.0f)), abs_tol); + // convert 0 float to bf6 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_sr(0.0f)), abs_tol); + // positive norm float value to bf6 and back with various scales, check if holds + float pos_float = 0.25f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_sr(pos_float)), + abs_tol); + // negative norm float value to bf6 and back with various scales, check if holds + float neg_float = -0.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)), + abs_tol); + // positive subnorm float value to bf6 and back with various scales, check if holds + pos_float = 0.1875f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_sr(pos_float)), + abs_tol); + // negative subnorm float value to bf6 and back with various scales, check if holds + neg_float = -0.0625f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), bf6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), bf6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), bf6_convert_sr(neg_float)), + abs_tol); +} + +TEST(BF6, TestSize) +{ + ASSERT_EQ(1, sizeof(bf6_t)); + ASSERT_EQ(12, sizeof(bf6x16_pk_t)); + ASSERT_EQ(24, sizeof(bf6x32_pk_t)); + ASSERT_EQ(16, sizeof(vector_type)); + ASSERT_EQ(32, sizeof(vector_type)); + ASSERT_EQ(32, sizeof(vector_type)); +} + +TEST(BF6, TestAlignment) +{ + ASSERT_EQ(1, alignof(bf6_t)); + ASSERT_EQ(4, alignof(bf6x16_pk_t)); + ASSERT_EQ(4, alignof(bf6x32_pk_t)); + ASSERT_EQ(16, alignof(vector_type)); + ASSERT_EQ(32, alignof(vector_type)); + ASSERT_EQ(32, alignof(vector_type)); +} + +// test vector of 1 bf6x16_pk_t, contains 16 bf6_t +TEST(BF6, TestAsType16x1) +{ + // test size + const int vector_size = 1; + const int packed_size = 16; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + test_vec_t test_vec = {bf6_t(0b000000), + bf6_t(0b100000), + bf6_t(0b000001), + bf6_t(0b100001), + bf6_t(0b000010), + bf6_t(0b100010), + bf6_t(0b000011), + bf6_t(0b100011), + bf6_t(0b000100), + bf6_t(0b100100), + bf6_t(0b000101), + bf6_t(0b100101), + bf6_t(0b000110), + bf6_t(0b100110), + bf6_t(0b001011), + bf6_t(0b101011)}; + // reference vector + vector_type right_vec; + // check default CTOR + 
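+    // bf6x16_pk_t packs 16 six-bit values into 96 bits (12 bytes) and bf6x32_pk_t
+    // packs 32 into 192 bits (24 bytes), matching the sizeof checks above; a
+    // default-constructed pack is expected to unpack to all zeros, verified below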
ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), + 0); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), + static_cast(test_vec[static_cast(i)])); + }); +} + +// test vector of 2 bf6x16_pk_t, contains 32 bf6_t +TEST(BF6, TestAsType16x2) +{ + // test size + const int vector_size = 2; + const int packed_size = 16; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + test_vec_t test_vec[2]; + test_vec[0] = {bf6_t(0b000000), + bf6_t(0b100000), + bf6_t(0b000001), + bf6_t(0b100001), + bf6_t(0b000010), + bf6_t(0b100010), + bf6_t(0b000011), + bf6_t(0b100011), + bf6_t(0b000100), + bf6_t(0b100100), + bf6_t(0b000101), + bf6_t(0b100101), + bf6_t(0b000110), + bf6_t(0b100110), + bf6_t(0b001011), + bf6_t(0b101011)}; + test_vec[1] = {bf6_t(0b010000), + bf6_t(0b110000), + bf6_t(0b010001), + bf6_t(0b110001), + bf6_t(0b010010), + bf6_t(0b110010), + bf6_t(0b010011), + bf6_t(0b110011), + bf6_t(0b010100), + bf6_t(0b110100), + bf6_t(0b010101), + bf6_t(0b110101), + bf6_t(0b010110), + bf6_t(0b110110), + bf6_t(0b011011), + bf6_t(0b111011)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { + ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { + ASSERT_EQ(right_vec.template AsType()(Number{}) + .template unpack<>(Number{}), + 0); + }); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec[i]); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { + ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { + ASSERT_EQ(left_vec.template AsType()(Number{}) + .template unpack<>(Number{}), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); + }); + }); +} + +// test vector of 1 bf6x32_pk_t, contains 32 bf6_t +TEST(BF6, TestAsType32x1) +{ + // test size + const int vector_size = 1; + const int packed_size = 32; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + test_vec_t test_vec = {bf6_t(0b000000), bf6_t(0b100000), bf6_t(0b000001), bf6_t(0b100001), + bf6_t(0b000010), bf6_t(0b100010), bf6_t(0b000011), bf6_t(0b100011), + bf6_t(0b000100), bf6_t(0b100100), bf6_t(0b000101), bf6_t(0b100101), + bf6_t(0b000110), bf6_t(0b100110), bf6_t(0b001011), bf6_t(0b101011), + bf6_t(0b010000), bf6_t(0b110000), bf6_t(0b010001), bf6_t(0b110001), + bf6_t(0b010010), bf6_t(0b110010), bf6_t(0b010011), bf6_t(0b110011), + bf6_t(0b010100), bf6_t(0b110100), bf6_t(0b010101), bf6_t(0b110101), + bf6_t(0b010110), bf6_t(0b110110), bf6_t(0b011011), bf6_t(0b111011)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), + 0); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = bf6x32_pk_t{}.pack(test_vec); + }); + // copy the vector 
+    vector_type<bf6x32_pk_t, 1> left_vec{right_vec};
+    // check if values were copied correctly
+    ck::static_for<0, packed_size, 1>{}([&](auto i) {
+        ASSERT_EQ(
+            left_vec.template AsType<bf6x32_pk_t>()(Number<0>{}).template unpack<>(Number<i>{}),
+            static_cast<int8_t>(test_vec[static_cast<int>(i)]));
+    });
+}
diff --git a/test/data_type/test_e8m0.cpp b/test/data_type/test_e8m0.cpp
new file mode 100644
index 0000000000..83bb0e38b8
--- /dev/null
+++ b/test/data_type/test_e8m0.cpp
@@ -0,0 +1,99 @@
+#include <gtest/gtest.h>
+#include "ck/utility/e8m0.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/type_convert.hpp"
+
+using namespace ck;
+
+TEST(E8M0, DefaultConstructor)
+{
+    e8m0_bexp_t exp;
+    EXPECT_EQ(exp.data, 0);
+}
+
+TEST(E8M0, InitConstructor)
+{
+    e8m0_bexp_t exp(0x7F);
+    EXPECT_EQ(exp.data, 0x7F);
+}
+
+TEST(E8M0, FloatConstructor)
+{
+    e8m0_bexp_t exp(1.0f);
+    EXPECT_EQ(exp.data, 0x7F);
+}
+
+TEST(E8M0, FloatConstructorNaN)
+{
+    e8m0_bexp_t exp(std::numeric_limits<float>::quiet_NaN());
+    EXPECT_EQ(exp.data, 0xFF);
+}
+
+TEST(E8M0, FloatConstructorZero)
+{
+    e8m0_bexp_t exp(0.0f);
+    EXPECT_EQ(exp.data, 0);
+}
+
+TEST(E8M0, ConversionToFloat)
+{
+    e8m0_bexp_t exp(0x7F);
+    float value = type_convert<float>(exp);
+    EXPECT_EQ(value, 1.0f);
+}
+
+TEST(E8M0, ConversionToFloatNaN)
+{
+    e8m0_bexp_t exp(0xFF);
+    float value = type_convert<float>(exp);
+    EXPECT_TRUE(std::isnan(value));
+}
+
+TEST(E8M0, MinValue)
+{
+    e8m0_bexp_t exp(0);
+    EXPECT_TRUE(exp == ck::NumericLimits<e8m0_bexp_t>::Min());
+
+    float value = type_convert<float>(exp);
+    EXPECT_EQ(value, std::powf(2, -ck::NumericUtils<e8m0_bexp_t>::bias));
+}
+
+TEST(E8M0, MaxValue)
+{
+    e8m0_bexp_t exp(254);
+    EXPECT_TRUE(exp == ck::NumericLimits<e8m0_bexp_t>::Max());
+
+    float value = type_convert<float>(exp);
+    EXPECT_EQ(value,
+              std::powf(2,
+                        ck::NumericLimits<e8m0_bexp_t>::Max().data -
+                            ck::NumericUtils<e8m0_bexp_t>::bias));
+}
+
+TEST(E8M0, EqualityOperator)
+{
+    e8m0_bexp_t exp1(0x7F);
+    e8m0_bexp_t exp2(0x7F);
+    EXPECT_TRUE(exp1 == exp2);
+}
+
+TEST(E8M0, InequalityOperator)
+{
+    e8m0_bexp_t exp1(0x7F);
+    e8m0_bexp_t exp2(0x80);
+    EXPECT_FALSE(exp1 == exp2);
+}
+
+TEST(E8M0, EqualityOperatorNaN)
+{
+    e8m0_bexp_t exp1(0xFF);
+    e8m0_bexp_t exp2(0xFF);
+    EXPECT_FALSE(exp1 == exp2);
+}
+
+TEST(E8M0, GetExponentValue)
+{
+    e8m0_bexp_t exp(0x7F);
+    int value = ck::utils::get_exponent_value(exp);
+    EXPECT_EQ(value, 0x7F);
+}
diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp
index a8077f1bdf..944dd89930 100644
--- a/test/data_type/test_fp8_ocp.cpp
+++ b/test/data_type/test_fp8_ocp.cpp
@@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest)
     float neg_float = -0.015625f; //-2^-6
     ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_rne(neg_float)), 0.0f);
 
-    // positive subnorm float value to fp8 and back, check if holds
-    pos_float = 0.00390625f;
+    // positive subnorm fp8 value to fp8 and back, check if holds
+    pos_float = 0.00390625f; // 2^-8
     ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_rne(pos_float)), abs_tol);
 
     // min subnorm fp8 value to fp8 and back, check if holds
diff --git a/test/data_type/test_fp4.cpp b/test/data_type/test_fp4.cpp
new file mode 100644
index 0000000000..f4b2bf3358
--- /dev/null
+++ b/test/data_type/test_fp4.cpp
@@ -0,0 +1,470 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
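+
+// f4_t is the 4-bit OCP MXFP4 (E2M1) format: 1 sign, 2 exponent, and 1 mantissa
+// bit, with no Inf/NaN encodings. The non-negative code points decode to
+// {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}; setting the sign bit (0x8) negates
+// them, hence Max() == 0x7 (6.0), Lowest() == 0xF (-6.0), and 0x1 (0.5) is the
+// only subnormal magnitude.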
+ +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/scaled_type_convert.hpp" + +using ck::e8m0_bexp_t; +using ck::f4_convert_rne; +using ck::f4_convert_sr; +using ck::f4_t; +using ck::f4x2_pk_t; +using ck::Number; +using ck::scaled_type_convert; +using ck::type_convert; +using ck::vector_type; + +TEST(FP4, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), f4_t{0x2}); + EXPECT_EQ(ck::NumericLimits::Max(), f4_t{0x7}); + EXPECT_EQ(ck::NumericLimits::Lowest(), f4_t{0xF}); + EXPECT_EQ(ck::NumericLimits::MinSubnorm(), f4_t{0x1}); + EXPECT_EQ(ck::NumericLimits::MaxSubnorm(), f4_t{0x1}); +} + +TEST(FP4, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum fp4 value + float max_fp4 = 6.0f; + // convert 0 float to fp4 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f4_convert_rne(0.0f)), abs_tol); + // convert maximal f4_t to float and check if equal to 6.0 + ASSERT_NEAR(max_fp4, type_convert(f4_convert_rne(max_fp4)), abs_tol); + // convert maximal float to fp4 and back, check if clipped to 6.0 + ASSERT_NEAR( + max_fp4, type_convert(f4_convert_rne(std::numeric_limits::max())), abs_tol); + // positive norm float value to fp4 and back, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float, type_convert(f4_convert_rne(pos_float)), abs_tol); + // negative norm float value to fp4 and back, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float, type_convert(f4_convert_rne(neg_float)), abs_tol); + // positive subnorm float value to fp4 and back, check if holds + pos_float = 0.5f; + ASSERT_NEAR(pos_float, type_convert(f4_convert_rne(pos_float)), abs_tol); + // negative subnorm float value to fp4 and back, check if holds + neg_float = -0.5f; + ASSERT_NEAR(neg_float, type_convert(f4_convert_rne(neg_float)), abs_tol); +} + +TEST(FP4, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum fp4 value + float max_fp4 = 6.0f; + // convert 0 float to fp4 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f4_convert_sr(0.0f)), abs_tol); + // convert maximal f4_t to float and check if equal to 6.0 + ASSERT_NEAR(max_fp4, type_convert(f4_convert_sr(max_fp4)), abs_tol); + // convert maximal float to fp4 and back, check if clipped to 6.0 + ASSERT_NEAR( + max_fp4, type_convert(f4_convert_sr(std::numeric_limits::max())), abs_tol); + // positive norm float value to fp4 and back, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float, type_convert(f4_convert_sr(pos_float)), abs_tol); + // negative norm float value to fp4 and back, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float, type_convert(f4_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to fp4 and back, check if holds + pos_float = 0.5f; + ASSERT_NEAR(pos_float, type_convert(f4_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to fp4 and back, check if holds + neg_float = -0.5f; + ASSERT_NEAR(neg_float, type_convert(f4_convert_sr(neg_float)), abs_tol); +} + +TEST(FP4, ScaledConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to fp4 and back with maximal scale, check if holds + ASSERT_NEAR( + 
0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_rne(0.0f)), abs_tol); + // convert 0 float to fp4 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_rne(0.0f)), abs_tol); + // positive norm float value to fp4 and back with various scales, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)), + abs_tol); + // negative norm float value to fp4 and back with various scales, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)), + abs_tol); + // positive subnorm float value to fp4 and back with various scales, check if holds + pos_float = 0.5f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_rne(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_rne(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_rne(pos_float)), + abs_tol); + // negative subnorm float value to fp4 and back with various scales, check if holds + neg_float = -0.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_rne(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_rne(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_rne(neg_float)), + abs_tol); +} + +TEST(FP4, ScaledConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to fp4 and back with maximal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_sr(0.0f)), abs_tol); + // convert 0 float to fp4 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_sr(0.0f)), abs_tol); + // positive norm float value to fp4 and back with various scales, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)), + abs_tol); + // negative norm float value to fp4 and back with various scales, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float * test_scale, + 
scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)), + abs_tol); + // positive subnorm float value to fp4 and back with various scales, check if holds + pos_float = 0.5f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_sr(pos_float)), + abs_tol); + // negative subnorm float value to fp4 and back with various scales, check if holds + neg_float = -0.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f4_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f4_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f4_convert_sr(neg_float)), + abs_tol); +} + +TEST(FP4, TestSize) +{ + ASSERT_EQ(1, sizeof(f4x2_pk_t)); + ASSERT_EQ(1, sizeof(vector_type)); + ASSERT_EQ(2, sizeof(vector_type)); + ASSERT_EQ(4, sizeof(vector_type)); + ASSERT_EQ(8, sizeof(vector_type)); + ASSERT_EQ(16, sizeof(vector_type)); + ASSERT_EQ(32, sizeof(vector_type)); +} + +TEST(FP4, TestAlignment) +{ + ASSERT_EQ(1, alignof(f4x2_pk_t)); + ASSERT_EQ(1, alignof(vector_type)); + ASSERT_EQ(2, alignof(vector_type)); + ASSERT_EQ(4, alignof(vector_type)); + ASSERT_EQ(8, alignof(vector_type)); + ASSERT_EQ(16, alignof(vector_type)); + ASSERT_EQ(32, alignof(vector_type)); +} + +// test vector of 1 f4x2_pk_t, contains 2 f4_t +TEST(FP4, TestAsType1) +{ + // test size + const int size = 1; + std::vector test_vec = {f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} + +// test vector of 2 f4x2_pk_t, contains 4 f4_t +TEST(FP4, TestAsType2) +{ + // test size + const int size = 2; + std::vector test_vec = {f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template 
AsType()(Number{}) = + f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} + +// test vector of 4 f4x2_pk_t, contains 8 f4_t +TEST(FP4, TestAsType4) +{ + // test size + const int size = 4; + std::vector test_vec = {f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, + f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} + +// test vector of 8 f4x2_pk_t, contains 16 f4_t +TEST(FP4, TestAsType8) +{ + // test size + const int size = 8; + std::vector test_vec = {f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, + f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, + f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} + +// test vector of 16 f4x2_pk_t, contains 32 f4_t +TEST(FP4, TestAsType16) +{ + // test size + const int size = 16; + std::vector test_vec = { + f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0010}, 
f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} + +// test vector of 32 f4x2_pk_t, contains 64 f4_t +TEST(FP4, TestAsType32) +{ + // test size + const int size = 32; + std::vector test_vec = { + f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, + f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, + f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0111}, + f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, f4x2_pk_t::type{0b0010}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, + f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1111}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0111}, f4x2_pk_t::type{0b1010}, f4x2_pk_t::type{0b0001}, + f4x2_pk_t::type{0b0010}, f4x2_pk_t::type{0b1001}, f4x2_pk_t::type{0b1001}, + f4x2_pk_t::type{0b1111}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + 
f4x2_pk_t{}.pack(test_vec.at(i), test_vec.at(i + 1)); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<0>{}), + test_vec.at(i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).template unpack<>(Number<1>{}), + test_vec.at(i + 1)); + }); +} diff --git a/test/data_type/test_fp6.cpp b/test/data_type/test_fp6.cpp new file mode 100644 index 0000000000..cf91e69db3 --- /dev/null +++ b/test/data_type/test_fp6.cpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/scaled_type_convert.hpp" + +using ck::e8m0_bexp_t; +using ck::f6_convert_rne; +using ck::f6_convert_sr; +using ck::f6_t; +using ck::f6x16_pk_t; +using ck::f6x32_pk_t; +using ck::Number; +using ck::scaled_type_convert; +using ck::type_convert; +using ck::vector_type; + +TEST(FP6, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), f6_t(0b001000)); + EXPECT_EQ(ck::NumericLimits::Max(), f6_t(0b011111)); + EXPECT_EQ(ck::NumericLimits::Lowest(), f6_t(0b111111)); + EXPECT_EQ(ck::NumericLimits::MinSubnorm(), f6_t(0b000001)); + EXPECT_EQ(ck::NumericLimits::MaxSubnorm(), f6_t(0b000111)); +} + +TEST(FP6, ConvertFP32Nearest) +{ + // set maximum fp6 value + float max_fp6 = 7.5f; + // convert 0 float to fp6 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f6_convert_rne(0.0f)), 0.0f); + // convert maximal f6_t to float and check if equal to max_fp6 + ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(max_fp6)), 0.0f); + // convert maximal float to fp6 and back, check if clipped to max_fp6 + ASSERT_NEAR( + max_fp6, type_convert(f6_convert_rne(std::numeric_limits::max())), 0.0f); + // convert float Inf to fp6 and back, check if clipped to max_fp6 + ASSERT_NEAR( + max_fp6, type_convert(f6_convert_rne(std::numeric_limits::infinity())), 0.0f); + // convert float value less than fp6 subnorm to fp6 and back, check if equal to 0.0 + float less_than_subnorm = 0.0625f; + ASSERT_NEAR(0.0f, type_convert(f6_convert_rne(less_than_subnorm)), 0.0f); + // convert float NaN to fp6 and back, check if clipped to max_fp6 + ASSERT_NEAR(max_fp6, + type_convert(f6_convert_rne(std::numeric_limits::quiet_NaN())), + 0.0f); + // positive norm float value to fp6 and back, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float, type_convert(f6_convert_rne(pos_float)), 0.0f); + // negative norm float value to fp6 and back, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float, type_convert(f6_convert_rne(neg_float)), 0.0f); + // positive subnorm float value to fp6 and back, check if holds + pos_float = 0.125f; + ASSERT_NEAR(pos_float, type_convert(f6_convert_rne(pos_float)), 0.0f); + // negative subnorm float value to fp6 and back, check if holds + neg_float = -0.25f; + ASSERT_NEAR(neg_float, type_convert(f6_convert_rne(neg_float)), 0.0f); +} + +TEST(FP6, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum fp6 value + float max_fp6 = 7.5f; + // convert 0 float to fp6 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f6_convert_sr(0.0f)), abs_tol); + // convert maximal f6_t to float and check if equal to max_fp6 + ASSERT_NEAR(max_fp6, type_convert(f6_convert_sr(max_fp6)), abs_tol); + // convert maximal float to fp6 and back, 
check if clipped to max_fp6 + ASSERT_NEAR( + max_fp6, type_convert(f6_convert_sr(std::numeric_limits::max())), abs_tol); + // convert float Inf to fp6 and back, check if clipped to max_fp6 + ASSERT_NEAR(max_fp6, + type_convert(f6_convert_sr(std::numeric_limits::infinity())), + abs_tol); + // convert float NaN to fp6 and back, check if clipped to max_fp6 + ASSERT_NEAR(max_fp6, + type_convert(f6_convert_sr(std::numeric_limits::quiet_NaN())), + abs_tol); + // positive norm float value to fp6 and back, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float, type_convert(f6_convert_sr(pos_float)), abs_tol); + // negative norm float value to fp6 and back, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float, type_convert(f6_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to fp6 and back, check if holds + pos_float = 0.125f; + ASSERT_NEAR(pos_float, type_convert(f6_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to fp6 and back, check if holds + neg_float = -0.25f; + ASSERT_NEAR(neg_float, type_convert(f6_convert_sr(neg_float)), abs_tol); +} + +TEST(FP6, ScaledConvertFP32Nearest) +{ + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to fp6 and back with maximal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_rne(0.0f)), 0.0f); + // convert 0 float to fp6 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_rne(0.0f)), 0.0f); + // positive norm float value to fp6 and back with various scales, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_rne(pos_float)), + 0.0f); + // negative norm float value to fp6 and back with various scales, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_rne(neg_float)), + 0.0f); + // positive subnorm float value to fp6 and back with various scales, check if holds + pos_float = 0.125f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_rne(pos_float)), + 0.0f); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_rne(pos_float)), + 0.0f); + // negative subnorm float value to fp6 and back with various scales, check if holds + neg_float = -0.25f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_rne(neg_float)), + 0.0f); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_rne(neg_float)), + 0.0f); + 
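+    // f6_t is the 6-bit OCP MXFP6 (E2M3) format (max normal 7.5, min normal 1.0,
+    // subnormal step 0.125, no Inf/NaN encodings), so the sample values used here
+    // and their power-of-two-scaled products are all exactly representable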
ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_rne(neg_float)), + 0.0f); +} + +TEST(FP6, ScaledConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // set maximum scale + float max_scale = type_convert(ck::NumericLimits::Max()); // 0xFE -> float + // set minimum scale + float min_scale = type_convert(ck::NumericLimits::Min()); // 0x00 -> float + // set arbitrary scale to 256.0 + float test_scale = 256.0f; // 0b10000111 + // convert 0 float to fp6 and back with maximal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_sr(0.0f)), abs_tol); + // convert 0 float to fp6 and back with minimal scale, check if holds + ASSERT_NEAR( + 0.0f, scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_sr(0.0f)), abs_tol); + // positive norm float value to fp6 and back with various scales, check if holds + float pos_float = 1.0f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_sr(pos_float)), + abs_tol); + // negative norm float value to fp6 and back with various scales, check if holds + float neg_float = -1.5f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)), + abs_tol); + // positive subnorm float value to fp6 and back with various scales, check if holds + pos_float = 0.125f; + ASSERT_NEAR(pos_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_sr(pos_float)), + abs_tol); + ASSERT_NEAR(pos_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_sr(pos_float)), + abs_tol); + // negative subnorm float value to fp6 and back with various scales, check if holds + neg_float = -0.25f; + ASSERT_NEAR(neg_float * test_scale, + scaled_type_convert(e8m0_bexp_t(test_scale), f6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * max_scale, + scaled_type_convert(e8m0_bexp_t(max_scale), f6_convert_sr(neg_float)), + abs_tol); + ASSERT_NEAR(neg_float * min_scale, + scaled_type_convert(e8m0_bexp_t(min_scale), f6_convert_sr(neg_float)), + abs_tol); +} + +TEST(FP6, TestSize) +{ + ASSERT_EQ(1, sizeof(f6_t)); + ASSERT_EQ(12, sizeof(f6x16_pk_t)); + ASSERT_EQ(24, sizeof(f6x32_pk_t)); + ASSERT_EQ(16, sizeof(vector_type)); + ASSERT_EQ(32, sizeof(vector_type)); + ASSERT_EQ(32, sizeof(vector_type)); +} + +TEST(FP6, TestAlignment) +{ + ASSERT_EQ(1, alignof(f6_t)); + ASSERT_EQ(4, alignof(f6x16_pk_t)); + ASSERT_EQ(4, alignof(f6x32_pk_t)); + ASSERT_EQ(16, alignof(vector_type)); + ASSERT_EQ(32, alignof(vector_type)); + ASSERT_EQ(32, alignof(vector_type)); +} + +// test vector of 1 f6x16_pk_t, contains 16 f6_t +TEST(FP6, TestAsType16x1) +{ + // test size + const int vector_size = 1; + const int packed_size = 16; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + test_vec_t test_vec = {f6_t(0b000000), + f6_t(0b100000), + f6_t(0b000001), + 
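+                           // each magnitude in this list appears twice: once with
+                           // the sign bit (0b100000) clear and once with it set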
f6_t(0b100001), + f6_t(0b000010), + f6_t(0b100010), + f6_t(0b000011), + f6_t(0b100011), + f6_t(0b000100), + f6_t(0b100100), + f6_t(0b000101), + f6_t(0b100101), + f6_t(0b000110), + f6_t(0b100110), + f6_t(0b001011), + f6_t(0b101011)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), + static_cast(test_vec[static_cast(i)])); + }); +} + +// test vector of 2 f6x16_pk_t, contains 32 f6_t +TEST(FP6, TestAsType16x2) +{ + // test size + const int vector_size = 2; + const int packed_size = 16; + typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); + test_vec_t test_vec[2]; + test_vec[0] = {f6_t(0b000000), + f6_t(0b100000), + f6_t(0b000001), + f6_t(0b100001), + f6_t(0b000010), + f6_t(0b100010), + f6_t(0b000011), + f6_t(0b100011), + f6_t(0b000100), + f6_t(0b100100), + f6_t(0b000101), + f6_t(0b100101), + f6_t(0b000110), + f6_t(0b100110), + f6_t(0b001011), + f6_t(0b101011)}; + test_vec[1] = {f6_t(0b010000), + f6_t(0b110000), + f6_t(0b010001), + f6_t(0b110001), + f6_t(0b010010), + f6_t(0b110010), + f6_t(0b010011), + f6_t(0b110011), + f6_t(0b010100), + f6_t(0b110100), + f6_t(0b010101), + f6_t(0b110101), + f6_t(0b010110), + f6_t(0b110110), + f6_t(0b011011), + f6_t(0b111011)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { + ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { + ASSERT_EQ(right_vec.template AsType()(Number{}) + .template unpack<>(Number{}), + 0); + }); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec[i]); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { + ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { + ASSERT_EQ(left_vec.template AsType()(Number{}) + .template unpack<>(Number{}), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); + }); + }); +} + +// test vector of 1 f6x32_pk_t, contains 32 f6_t +TEST(FP6, TestAsType32x1) +{ + // test size + const int vector_size = 1; + const int packed_size = 32; + typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); + test_vec_t test_vec = {f6_t(0b000000), f6_t(0b100000), f6_t(0b000001), f6_t(0b100001), + f6_t(0b000010), f6_t(0b100010), f6_t(0b000011), f6_t(0b100011), + f6_t(0b000100), f6_t(0b100100), f6_t(0b000101), f6_t(0b100101), + f6_t(0b000110), f6_t(0b100110), f6_t(0b001011), f6_t(0b101011), + f6_t(0b010000), f6_t(0b110000), f6_t(0b010001), f6_t(0b110001), + f6_t(0b010010), f6_t(0b110010), f6_t(0b010011), f6_t(0b110011), + f6_t(0b010100), f6_t(0b110100), f6_t(0b010101), f6_t(0b110101), + f6_t(0b010110), f6_t(0b110110), f6_t(0b011011), f6_t(0b111011)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + right_vec.template 
AsType()(Number<0>{}).template unpack<>(Number{}), 0); + }); + // assign test values to the vector + ck::static_for<0, vector_size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = f6x32_pk_t{}.pack(test_vec); + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, packed_size, 1>{}([&](auto i) { + ASSERT_EQ( + left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), + static_cast(test_vec[static_cast(i)])); + }); +} diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp index a8077f1bdf..944dd89930 100644 --- a/test/data_type/test_fp8_ocp.cpp +++ b/test/data_type/test_fp8_ocp.cpp @@ -60,8 +60,8 @@ TEST(FP8OCP, ConvertFP32Nearest) float neg_float = -0.015625f; //-2^-6 ASSERT_NEAR(neg_float, type_convert(f8_convert_rne(neg_float)), 0.0f); - // positive subnorm float value to fp8 and back, check if holds - pos_float = 0.00390625f; + // positive subnorm fp8 value to fp8 and back, check if holds + pos_float = 0.00390625f; // 2^-8 ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); // min subnorm fp8 value to fp8 and back, check if holds diff --git a/test/data_type/test_mx_bf8.cpp b/test/data_type/test_mx_bf8.cpp new file mode 100644 index 0000000000..9f947e1d78 --- /dev/null +++ b/test/data_type/test_mx_bf8.cpp @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/scaled_type_convert.hpp" + +using ck::bf8_ocp_t; +using ck::bf8x16_ocp_t; +using ck::bf8x2_ocp_t; +using ck::bf8x32_ocp_t; +using ck::e8m0_bexp_t; +using ck::float16_t; +using ck::float2_t; +using ck::float32_t; +using ck::mxf8_convert_rne; +using ck::mxf8_convert_sr; +using ck::scaled_type_convert; +using ck::type_convert; + +constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6; + +/** + * @brief Tests conversion of BF8 values to float using E8M0 exponent scaling. + * + * This function performs a series of conversions from BF8 values to float values using + * E8M0 exponent scaling. It handles all possible combinations of E8M0 and BF8 values, + * as well as specific vector and rounding conversions. + * + * @param N The maximum number of conversions to perform. + * @param p_test Pointer to the output array where the converted float values will be stored. + * @param p_completed Pointer to a variable that tracks the number of completed conversions. + * + * @note If either p_test or p_completed is nullptr, the function will return immediately. + * @note The function will stop converting if the number of conversions reaches N. + * @note First 256*256 conversions are for all possible combinations of E8M0 and BF8 values that are + * stored in memory sequentially with BF8 values varying faster. + * + * The function performs the following conversions: + * - All possible combinations of E8M0 and BF8 values. [256x256] + * - Vector conversions bf8x2 -> f32x2. [2] + * - Vector conversions f32x2 -> bf8x2 rne. [2] + * - Vector conversions f32x2 -> bf8x2 sr. [2] + * - Round to nearest even conversions for specific float values. [6] + * + * The results are stored in the p_test array, and the number of completed conversions + * is updated in the p_completed variable. 
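+ *
+ * For reference, the flat layout of the exhaustive sweep can be expressed with a
+ * small helper (hypothetical, for illustration only; the test computes the index
+ * inline):
+ * @code
+ * // index of the result for scale byte exp_id and BF8 byte bf8_id;
+ * // BF8 varies fastest, so consecutive results share one scale value
+ * inline uint64_t mx_bf8_result_index(uint32_t exp_id, uint32_t bf8_id)
+ * {
+ *     return static_cast<uint64_t>(exp_id) * 256u + bf8_id;
+ * }
+ * @endcode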
+ */ +__host__ __device__ void +test_mx_bf8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + // All possible combinations of E8M0 and BF8 + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = scaled_type_convert(e8m0_bexp_t(exp_id), bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + } + + /// Test vector conversions + // bf8x2 -> f32x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + auto scale = e8m0_bexp_t(8.0f); + + float2_t f32x2 = scaled_type_convert(scale, bf8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // f32x2 -> bf8x2 + f32x2 = {-8.0f, 4.0f}; + auto scale2 = e8m0_bexp_t(2.0f); + + bf8x2 = mxf8_convert_rne(f32x2, type_convert(scale2)); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + auto scale4 = e8m0_bexp_t(4.0f); + + bf8x2 = mxf8_convert_sr(f32x2, type_convert(scale4)); // expect {-2, 1} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-2f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 1f + if(i >= N) + { + return; + } + + /// Test round to nearest even + + p_test[i++] = type_convert(mxf8_convert_rne(1024.0f, 4.0f)); // 1024/4 + if(i >= N) + { + return; + } + + p_test[i++] = type_convert( + mxf8_convert_rne(std::numeric_limits::quiet_NaN(), 4.0f)); // => NaN + if(i >= N) + { + return; + } + + p_test[i++] = type_convert(mxf8_convert_rne( + std::numeric_limits::infinity(), 2.0f)); // => BF8 Inf on device + if(i >= N) + { + return; + } + + // 31000/0.5 > 57344 => BF8 Inf on device + p_test[i++] = type_convert(mxf8_convert_rne(31000.0f, 0.5f)); + if(i >= N) + { + return; + } + + // -31000/0.5 < -57344 => -BF8 Inf on device + p_test[i++] = type_convert(mxf8_convert_rne(-31000.0f, 0.5f)); + if(i >= N) + { + return; + } + + p_test[i++] = type_convert( + mxf8_convert_rne(powf(2.0f, 16.0f), 4.0f)); // 2^16/4 = 65536/4 + if(i >= N) + { + return; + } +} + +TEST(MXBF8, HostScaledConvert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_mx_bf8_scaled_convert(test_size, out.data(), &completed); + + // V = X * P; X - E8M0 scale, P - BF8 + + // If X = NaN, then V = NaN regardless of P + uint8_t e8m0_nan_id = ck::NumericLimits::QuietNaN().data; + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + auto idx = e8m0_nan_id * 256 + bf8_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + // If P in {Inf, NaN}, then V = P + std::set bf8_spec_ids; + bf8_spec_ids.insert(0b11111111); // -NaN + bf8_spec_ids.insert(0b01111111); // +NaN + bf8_spec_ids.insert(0b11111101); // -NaN + bf8_spec_ids.insert(0b01111101); // +NaN + bf8_spec_ids.insert(0b11111110); // -NaN + bf8_spec_ids.insert(0b01111110); // +NaN + bf8_spec_ids.insert(0b11111100); // -inf + bf8_spec_ids.insert(0b01111100); // +inf + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(auto bf8_spec_id : bf8_spec_ids) + { + auto idx = exp_id * 256 + bf8_spec_id; + + if(std::isnan(type_convert(bf8_ocp_t{bf8_spec_id}))) + { + 
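+                // NaN never compares equal to itself, so the NaN cases are checked
+                // with isnan() rather than ASSERT_EQ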
ASSERT_TRUE(std::isnan(out[idx])) + << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx]; + } + else + { + ASSERT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_spec_id})) + << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx]; + } + } + } + + // V = X * P; X, P - finite + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = exp_id * 256 + bf8_uid; + ASSERT_FLOAT_EQ(out[idx], + type_convert(e8m0_bexp_t(exp_id)) * + type_convert(bf8_ocp_t{bf8_uid})) + << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_uid}); + } + } + + /// Test vector conversions + + auto i = 256 * 256; + + // bf8x2 -> f32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -11.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -13.0f)); + + // f32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -2.0f); + EXPECT_EQ(out[i++], 1.0f); + + /// Test round to nearest even + EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1]; + EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; + + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1]; + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void test_mx_bf8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_mx_bf8_scaled_convert(N, p_test, p_completed); +} + +TEST(MXBF8, DeviceScaledConvert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_bf8_device_scaled_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + // V = X * P; X - E8M0 scale, P - BF8 + + // If X = NaN, then V = NaN regardless of P + uint8_t e8m0_nan_id = ck::NumericLimits::QuietNaN().data; + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + auto idx = e8m0_nan_id * 256 + bf8_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + // If P in {Inf, NaN}, then V = P + std::set bf8_spec_ids; + bf8_spec_ids.insert(0b11111111); //-NaN + bf8_spec_ids.insert(0b01111111); // +NaN + bf8_spec_ids.insert(0b11111101); //-NaN + bf8_spec_ids.insert(0b01111101); // +NaN + bf8_spec_ids.insert(0b11111110); //-NaN + bf8_spec_ids.insert(0b01111110); // +NaN + bf8_spec_ids.insert(0b11111100); //-inf + bf8_spec_ids.insert(0b01111100); // +inf + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(auto bf8_spec_id : 
bf8_spec_ids) + { + auto idx = exp_id * 256 + bf8_spec_id; + + if(std::isnan(type_convert(bf8_ocp_t{bf8_spec_id}))) + { + ASSERT_TRUE(std::isnan(out[idx])) + << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx]; + } + else + { + ASSERT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_spec_id})) + << "exp_id: " << exp_id << " bf8_id: " << bf8_spec_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_spec_id}) << " != " << out[idx]; + } + } + } + + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_spec_ids.find(bf8_id) != bf8_spec_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = exp_id * 256 + bf8_uid; + ASSERT_FLOAT_EQ(out[idx], + type_convert(e8m0_bexp_t(exp_id)) * + type_convert(bf8_ocp_t{bf8_uid})) + << "exp_id: " << exp_id << " bf8_id: " << bf8_uid << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(bf8_ocp_t{bf8_uid}); + } + } + + /// Test vector conversions + + auto i = 256 * 256; + + // bf8x2 -> f32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -11.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -13.0f)); + + // f32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -2.0f); + EXPECT_EQ(out[i++], 1.0f); + + /// Test round to nearest even + EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1]; + EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1]; +#if 1 + EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1]; + EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1]; + EXPECT_TRUE(std::isinf(out[i++])) << "out[i-1]: " << out[i - 1]; +#else + // NOTE: Host and Device have different behavior. + // Device returns Infs, while Host returns Max (saturation to finite value). 
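+    // For example, BF8 (E5M2) has a maximum finite magnitude of 57344 = 1.75 * 2^15;
+    // 31000 / 0.5 = 62000 overflows it, so the device path yields +/-Inf while the
+    // host path saturates to Max()/Lowest().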
+ EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) + << "out[i-1]: " << out[i - 1]; +#endif + EXPECT_EQ(out[i++], powf(2.0f, 14.0f)) << "out[i-1]: " << out[i - 1]; + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ float vec16_generator(ck::index_t i) { return powf(-1.0f, i) * powf(2.0f, i); } + +__global__ void test_mx_bf8x16_device_scaled_convert(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 16; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + auto scale2 = e8m0_bexp_t(2.0f); + + bf8x16_ocp_t bf8x16{}; + float16_t float16{}; + ck::static_for<0, N, 1>{}( + [&](auto ii) { float16[static_cast(ii)] = vec16_generator(ii); }); + + bf8x16 = scaled_type_convert(scale2, float16); + + ck::static_for<0, N, 1>{}([&](auto ii) { + p_test[i++] = type_convert(bf8x16.AsType()(ck::Number{})); + }); +} + +TEST(MXBF8, DeviceF32x16ToBF8x16ScaledConvert) +{ + constexpr int N = 16; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_bf8x16_device_scaled_convert<<<1, 1>>>( + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + auto i = 0; + + ck::static_for<0, N, 1>{}([&](auto ii) { + EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl; + }); + + EXPECT_EQ(N, completed); + EXPECT_EQ(N, i); +} + +__host__ __device__ float vec32_generator(ck::index_t i) +{ + if(i < 16) + { + return vec16_generator(i % 16); + } + else + { + return 1.5f * vec16_generator(i % 16); + } +} + +__global__ void test_mx_bf8x32_device_scaled_convert(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + auto scale2 = e8m0_bexp_t(2.0f); + + bf8x32_ocp_t bf8x32{}; + float32_t float32{}; + ck::static_for<0, N, 1>{}( + [&](auto ii) { float32[static_cast(ii)] = vec32_generator(ii); }); + + bf8x32 = mxf8_convert_rne(float32, type_convert(scale2)); + + ck::static_for<0, N, 1>{}([&](auto ii) { + p_test[i++] = type_convert(bf8x32.AsType()(ck::Number{})); + }); +} + +TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvert) +{ + constexpr int N = 32; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_bf8x32_device_scaled_convert<<<1, 1>>>( + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + auto i = 0; + + ck::static_for<0, N, 1>{}([&](auto ii) { + EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl; + }); + + EXPECT_EQ(N, completed); + EXPECT_EQ(N, i); +} + +__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + 
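+    // With scale = 8, every vec32_generator(i) / 8 should be exactly representable
+    // in BF8 (E5M2), so stochastic rounding has no residue to randomize and the
+    // host-side EXPECT_EQ checks below can be exact.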
+__global__ void test_mx_bf8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 32;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(8.0f);
+
+    bf8x32_ocp_t bf8x32{};
+    float32_t float32{};
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
+
+    bf8x32 = mxf8_convert_sr(float32, type_convert<float>(scale2));
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        p_test[i++] = type_convert<float>(bf8x32.AsType<bf8_ocp_t>()(ck::Number<ii>{}));
+    });
+}
+
+TEST(MXBF8, DeviceF32x32ToBF8x32ScaledConvertSR)
+{
+    constexpr int N = 32;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_bf8x32_device_scaled_convert_sr<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
+
+__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 32;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(4.0f);
+
+    bf8x32_ocp_t bf8x32{};
+    float32_t float32{};
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        bf8x32.AsType<bf8_ocp_t>()(ii) = type_convert<bf8_ocp_t>(vec32_generator(ii) / 16.0f);
+    });
+
+    float32 = scaled_type_convert(scale2, bf8x32);
+
+    ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
+}
+
+TEST(MXBF8, DeviceBF8x32ToF32x32ScaledConvert)
+{
+    constexpr int N = 32;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
diff --git a/test/data_type/test_mx_fp8.cpp b/test/data_type/test_mx_fp8.cpp
new file mode 100644
index 0000000000..3a2bd4b88b
--- /dev/null
+++ b/test/data_type/test_mx_fp8.cpp
@@ -0,0 +1,616 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gtest/gtest.h"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/utility/scaled_type_convert.hpp"
+
+using ck::e8m0_bexp_t;
+using ck::f8_ocp_t;
+using ck::f8x16_ocp_t;
+using ck::f8x2_ocp_t;
+using ck::f8x32_ocp_t;
+using ck::float16_t;
+using ck::float2_t;
+using ck::float32_t;
+using ck::mxf8_convert_rne;
+using ck::mxf8_convert_sr;
+using ck::scaled_type_convert;
+using ck::type_convert;
+using ck::fp8_impl::fp8x2_storage_t;
+
+constexpr uint64_t test_size = 256 * 256 + 2 + 4 + 6;
+
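+// E8M0 ("bexp") stores only a biased exponent: byte e encodes the scale
+// 2^(e - 127); there is no sign or mantissa bit, and e == 0xFF encodes NaN
+// (per the OCP MX spec). For example:
+//
+//   e8m0_bexp_t(2.0f) // byte 128 -> 2^(128 - 127) = 2
+//   e8m0_bexp_t(8.0f) // byte 130 -> 2^(130 - 127) = 8
+//
+// A scaled decode V = X * P therefore only shifts the FP8 value's exponent.
+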
+/**
+ * @brief Tests conversion of FP8 values to float using E8M0 exponent scaling.
+ *
+ * This function performs a series of conversions from FP8 values to float values using
+ * E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP8 values,
+ * as well as specific vector and rounding conversions.
+ *
+ * @param N The maximum number of conversions to perform.
+ * @param p_test Pointer to the output array where the converted float values will be stored.
+ * @param p_completed Pointer to a variable that tracks the number of completed conversions.
+ *
+ * @note If either p_test or p_completed is nullptr, the function will return immediately.
+ * @note The function will stop converting if the number of conversions reaches N.
+ * @note The first 256*256 conversions cover all possible combinations of E8M0 and FP8 values;
+ * the results are stored in memory sequentially, with the FP8 value varying faster.
+ *
+ * The function performs the following conversions:
+ * - All possible combinations of E8M0 and FP8 values. [256x256]
+ * - Vector conversions f8x2 -> f32x2. [2]
+ * - Vector conversions f32x2 -> f8x2 rne. [2]
+ * - Vector conversions f32x2 -> f8x2 sr. [2]
+ * - Round to nearest even conversions for specific float values. [6]
+ *
+ * The results are stored in the p_test array, and the number of completed conversions
+ * is updated in the p_completed variable.
+ */
+__host__ __device__ void
+test_mx_fp8_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
+{
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    // All possible combinations of E8M0 and FP8
+    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
+    {
+        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+        {
+            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
+            auto v          = scaled_type_convert(e8m0_bexp_t(exp_id), f8_ocp_t{fp8_uid});
+            p_test[i]       = v;
+            i++;
+            if(i >= N)
+            {
+                return;
+            }
+        }
+    }
+
+    /// Test vector conversions
+    // f8x2 -> f32x2
+    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
+    auto scale2 = e8m0_bexp_t(2.0f);
+
+    float2_t f32x2 = scaled_type_convert(scale2, fp8x2);
+    p_test[i++]    = f32x2[0];
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = f32x2[1];
+    if(i >= N)
+    {
+        return;
+    }
+
+    // f32x2 -> f8x2
+    f32x2 = {-8.0f, 4.0f};
+    fp8x2 = mxf8_convert_rne(f32x2, type_convert<float>(scale2)); // expect {-4, 2}
+
+    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
+    if(i >= N)
+    {
+        return;
+    }
+
+    auto scale4 = e8m0_bexp_t(4.0f);
+
+    fp8x2 = mxf8_convert_sr(f32x2, type_convert<float>(scale4)); // expect {-2, 1}
+
+    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-2f
+    if(i >= N)
+    {
+        return;
+    }
+    p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 1f
+    if(i >= N)
+    {
+        return;
+    }
+
+    /// Test round to nearest even
+
+    p_test[i++] = type_convert<float>(mxf8_convert_rne(1024.0f, 4.0f)); // 1024/4
+    if(i >= N)
+    {
+        return;
+    }
+
+    p_test[i++] = type_convert<float>(
+        mxf8_convert_rne(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
+    if(i >= N)
+    {
+        return;
+    }
+
+    // Inf/2 > 448 => NaN on device
+    p_test[i++] = type_convert<float>(
+        mxf8_convert_rne(std::numeric_limits<float>::infinity(), 2.0f));
+    if(i >= N)
+    {
+        return;
+    }
+
+    // 256/0.5 > 448 => NaN on device
+    p_test[i++] = type_convert<float>(mxf8_convert_rne(256.0f, 0.5f));
+    if(i >= N)
+    {
+        return;
+    }
+
+    // -256/0.5 < -448 => NaN on device
+    p_test[i++] = type_convert<float>(mxf8_convert_rne(-256.0f, 0.5f));
+    if(i >= N)
+    {
+        return;
+    }
+
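+    // Worked example for the conversion below (e4m3 tops out at 448 = 1.75 * 2^8):
+    // to encode x = 10000 ~ 2^13.3 we need 10000 / scale <= 448, i.e. scale >= 22.3,
+    // and the smallest power-of-two scale is 2^5 = 32. The stored element is then
+    // rne(10000 / 32) = rne(312.5) = 320, because e4m3 steps by 32 between 256 and
+    // 512; that is exactly what type_convert<f8_ocp_t>(312.5f) produces host-side.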
+    // proper scale selection 2^13 < 10000; 2^8 < 448 => scale = 2^(13-8) = 2^5
+    p_test[i++] =
+        type_convert<float>(mxf8_convert_rne(10000.0f, 32.0f)); // 10000/32 = 312.5
+    if(i >= N)
+    {
+        return;
+    }
+}
+
+TEST(MXFP8, HostScaledConvert)
+{
+    std::vector<float> out(test_size, -1.0f);
+    uint64_t completed = 0;
+
+    test_mx_fp8_scaled_convert(test_size, out.data(), &completed);
+
+    // V = X * P; X - E8M0 scale, P - FP8
+
+    // If X = NaN, then V = NaN regardless of P
+    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        auto idx = e8m0_nan_id * 256 + fp8_id;
+        ASSERT_TRUE(std::isnan(out[idx]));
+    }
+
+    // If P in {Inf, NaN}, then V = P
+    std::set<uint8_t> fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
+    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
+    {
+        if(exp_id == e8m0_nan_id)
+            continue;
+        for(auto fp8_nan_id : fp8_nan_ids)
+        {
+            auto idx = exp_id * 256 + fp8_nan_id;
+            ASSERT_TRUE(std::isnan(out[idx]));
+        }
+    }
+
+    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
+    {
+        if(exp_id == e8m0_nan_id)
+            continue;
+        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+        {
+            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+                continue;
+
+            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
+            auto idx        = exp_id * 256 + fp8_uid;
+            ASSERT_FLOAT_EQ(out[idx],
+                            type_convert<float>(e8m0_bexp_t(exp_id)) *
+                                type_convert<float>(f8_ocp_t{fp8_uid}))
+                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
+                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
+                << type_convert<float>(f8_ocp_t{fp8_uid});
+        }
+    }
+
+    /// Test vector conversions
+
+    auto i = 256 * 256;
+
+    // f8x2 -> f32x2
+    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
+    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
+
+    // f32x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], -4.0f);
+    EXPECT_EQ(out[i++], 2.0f);
+    // SR
+    EXPECT_EQ(out[i++], -2.0f);
+    EXPECT_EQ(out[i++], 1.0f);
+
+    /// Test round to nearest even
+    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
+    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
+        << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
+        << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
+        << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
+        << "out[i-1]: " << out[i - 1];
+
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}
+
+__global__ void test_mx_fp8_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
+{
+    test_mx_fp8_scaled_convert(N, p_test, p_completed);
+}
+
+TEST(MXFP8, DeviceScaledConvert)
+{
+    std::vector<float> out(test_size, -1.0f);
+
+    DeviceMem device_out(test_size * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_fp8_device_scaled_convert<<<1, 1>>>(
+        test_size,
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    // V = X * P; X - E8M0 scale, P - FP8
+
+    // If X = NaN, then V = NaN regardless of P
+    uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
+    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+    {
+        auto idx = e8m0_nan_id * 256 + fp8_id;
+        ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
+    }
+
+    // If P in {Inf, NaN}, then V = P
+    std::set<uint8_t> fp8_nan_ids;
+    fp8_nan_ids.insert(0b11111111); //-NaN
+    fp8_nan_ids.insert(0b01111111); // +NaN
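+    // OCP e4m3 reserves a single NaN encoding per sign, S.1111.111 (0x7F/0xFF),
+    // and has no infinity; every other bit pattern is finite. (OCP e5m2, by
+    // contrast, keeps IEEE-style Inf at S.11111.00.) Hence exactly two ids are
+    // excluded from the exhaustive product check below.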
+    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
+    {
+        if(exp_id == e8m0_nan_id)
+            continue;
+        for(auto fp8_nan_id : fp8_nan_ids)
+        {
+            auto idx = exp_id * 256 + fp8_nan_id;
+            ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
+        }
+    }
+
+    for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
+    {
+        if(exp_id == e8m0_nan_id)
+            continue;
+        for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
+        {
+            if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
+                continue;
+
+            uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
+            auto idx        = exp_id * 256 + fp8_uid;
+            ASSERT_FLOAT_EQ(out[idx],
+                            type_convert<float>(e8m0_bexp_t(exp_id)) *
+                                type_convert<float>(f8_ocp_t{fp8_uid}))
+                << "exp_id: " << exp_id << " fp8_id: " << fp8_id << std::endl
+                << type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
+                << type_convert<float>(f8_ocp_t{fp8_uid});
+        }
+    }
+
+    /// Test vector conversions
+
+    auto i = 256 * 256;
+
+    // f8x2 -> f32x2
+    EXPECT_EQ(out[i++], -powf(2.0f, -5.0f));
+    EXPECT_EQ(out[i++], powf(2.0f, -8.0f));
+
+    // f32x2 -> fp8x2
+    // RNE
+    EXPECT_EQ(out[i++], -4.0f);
+    EXPECT_EQ(out[i++], 2.0f);
+    // SR
+    EXPECT_EQ(out[i++], -2.0f);
+    EXPECT_EQ(out[i++], 1.0f);
+
+    /// Test round to nearest even
+    EXPECT_EQ(out[i++], 1024.0f / 4.0f) << "out[i-1]: " << out[i - 1];
+    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
+#if 1
+    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
+    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
+    EXPECT_TRUE(std::isnan(out[i++])) << "out[i-1]: " << out[i - 1];
+#else
+    // NOTE: Host and Device have different behavior.
+    // Device returns NaN, while Host returns Max (saturation to finite value).
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
+        << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Max()))
+        << "out[i-1]: " << out[i - 1];
+    EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f8_ocp_t>::Lowest()))
+        << "out[i-1]: " << out[i - 1];
+#endif
+    EXPECT_EQ(out[i++], type_convert<float>(type_convert<f8_ocp_t>(312.5f)))
+        << "out[i-1]: " << out[i - 1];
+
+    EXPECT_EQ(test_size, completed);
+    EXPECT_EQ(test_size, i);
+}
+
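+// Generator note: the bf8 (e5m2) test can use powf(2.0f, i) for i < 16 because
+// e5m2 reaches 2^15; e4m3 tops out at 448 < 2^9, so here the exponent wraps via
+// i % 8 and the second half of the vector flips sign instead. Every generated
+// value is +/- 2^(0..7), which stays exactly representable after power-of-two
+// scaling.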
+__host__ __device__ float vec16_generator(ck::index_t i)
+{
+    return (i < 8 ? -1.0 : 1.0) * powf(2.0f, i % 8);
+}
+
+__global__ void test_mx_fp8x16_device_scaled_convert(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 16;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(2.0f);
+
+    f8x16_ocp_t fp8x16{};
+    float16_t float16{};
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { float16[static_cast<int>(ii)] = vec16_generator(ii); });
+
+    fp8x16 = scaled_type_convert(scale2, float16);
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        p_test[i++] = type_convert<float>(fp8x16.AsType<f8_ocp_t>()(ck::Number<ii>{}));
+    });
+}
+
+TEST(MXFP8, DeviceF32x16ToF8x16ScaledConvert)
+{
+    constexpr int N = 16;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_fp8x16_device_scaled_convert<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec16_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
+
+__host__ __device__ float vec32_generator(ck::index_t i)
+{
+    if(i < 16)
+    {
+        return vec16_generator(i % 16);
+    }
+    else
+    {
+        return 1.5f * vec16_generator(i % 16);
+    }
+}
+
+__global__ void test_mx_fp8x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 32;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(2.0f);
+
+    f8x32_ocp_t fp8x32{};
+    float32_t float32{};
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
+
+    fp8x32 = mxf8_convert_rne(float32, type_convert<float>(scale2));
+
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
+}
+
+TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvert)
+{
+    constexpr int N = 32;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_fp8x32_device_scaled_convert<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec32_generator(ii) / 2.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
+
+__global__ void test_mx_fp8x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 32;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(8.0f);
+
+    f8x32_ocp_t fp8x32{};
+    float32_t float32{};
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
+
+    fp8x32 = mxf8_convert_sr(float32, type_convert<float>(scale2));
+
+    ck::static_for<0, N, 1>{}(
+        [&](auto ii) { p_test[i++] = type_convert<float>(fp8x32.AsType<f8_ocp_t>()(ii)); });
+}
+
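+// With scale = 8 the expected elements are +/- 2^(-3..4) and +/- 1.5 * 2^(-3..4),
+// all exactly representable in e4m3 (at most one mantissa bit set), so the
+// stochastic rounding above is deterministic and plain EXPECT_EQ is safe below.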
+TEST(MXFP8, DeviceF32x32ToF8x32ScaledConvertSR)
+{
+    constexpr int N = 32;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_fp8x32_device_scaled_convert_sr<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec32_generator(ii) / 8.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
+
+__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
+{
+    constexpr int N = 32;
+    if(p_completed == nullptr)
+    {
+        return;
+    }
+
+    uint64_t& i = *p_completed;
+    i           = 0;
+
+    if(p_test == nullptr)
+    {
+        return;
+    }
+
+    auto scale2 = e8m0_bexp_t(4.0f);
+
+    f8x32_ocp_t fp8x32{};
+    float32_t float32{};
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        fp8x32.AsType<f8_ocp_t>()(ii) = type_convert<f8_ocp_t>(vec32_generator(ii) / 16.0f);
+    });
+
+    float32 = scaled_type_convert(scale2, fp8x32);
+
+    ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
+}
+
+TEST(MXFP8, DeviceF8x32ToF32x32ScaledConvert)
+{
+    constexpr int N = 32;
+    std::vector<float> out(N, -1.0f);
+
+    DeviceMem device_out(N * sizeof(float));
+    DeviceMem device_completed(sizeof(uint64_t));
+
+    device_out.SetValue(-21.0f);
+    device_completed.SetValue(-21.0f);
+
+    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
+        static_cast<float*>(device_out.GetDeviceBuffer()),
+        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
+
+    uint64_t completed = 0;
+    device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    auto i = 0;
+
+    ck::static_for<0, N, 1>{}([&](auto ii) {
+        EXPECT_EQ(out[i++], vec32_generator(ii) / 4.0f) << "ii: " << ii << std::endl;
+    });
+
+    EXPECT_EQ(N, completed);
+    EXPECT_EQ(N, i);
+}
diff --git a/test/mx_mfma_op/CMakeLists.txt b/test/mx_mfma_op/CMakeLists.txt
new file mode 100644
index 0000000000..6715265ae6
--- /dev/null
+++ b/test/mx_mfma_op/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_custom_target(test_mx_mfma)
+
+add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_mx_mfma_op PRIVATE utility)
+endif()
+add_dependencies(test_mx_mfma test_mx_mfma_op)
+
+
diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp
new file mode 100644
index 0000000000..cc612794f4
--- /dev/null
+++ b/test/mx_mfma_op/mx_mfma_op.cpp
@@ -0,0 +1,65 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "gtest/gtest.h" + +#include "mx_mfma_op.hpp" + +using ck::e8m0_bexp_t; +using ck::f8_t; +using ck::half_t; +using ck::type_convert; + +/** + * @brief Run the test for the given MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mfma_test(ck::index_t init) +{ + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using CPUAccType = AccType; + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + + const auto mx_mfma_kernel = ck::matmul; + + bool pass = true; + + pass = ck::mfma_test::TestMFMA{}(mx_mfma_kernel, init); + + return pass; +} + +TEST(MFMA, FP8MFMA16x16x128) +{ + auto AB_init = 0; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP8MFMA32x32x64) +{ + auto AB_init = 0; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp new file mode 100644 index 0000000000..e96e1b0b29 --- /dev/null +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -0,0 +1,567 @@ +#pragma once + +#include "ck/ck.hpp" + +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +namespace ck { + +// MFMA instructions supported in this test +enum class MFMA_F8F6F4 +{ + F32_16x16x128 = + static_cast(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4 + F32_32x32x64 = + static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4 +}; + +template +struct mfma_type_selector; + +template +struct mfma_type_selector +{ + __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); + } +}; + +template +struct mfma_type_selector +{ + __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); + } +}; + +template +static constexpr int32_t vectorSize(const VecT&) +{ + return scalar_type::vector_size; +} + +// Define a load function for input A blocks: +// Size: (BLOCK_M x BLOCK_K) +// ASSUMPTION: +// - We want contiguous BLOCK_M sized column neighbors in register. +// - Data is in col_major format +// This means: +// - From A we will load K columns of size BLOCK_M to satisfy our input data +template +__device__ AFragT load_A_col_major(AType const* input_ptr) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 
+    // clang-format off
+    // Register Mapping for 16x128:                                                         || Register Mapping for 32x64:
+    // Size             |  BLOCK_M  |  BLOCK_M  |  BLOCK_M  |  BLOCK_M  |        || Size             |  BLOCK_M  |  BLOCK_M  |
+    // M                | 0 ... 15  | 0 ... 15  | 0 ... 15  | 0 ... 15  |        || M                | 0 ... 31  | 0 ... 31  |
+    // Thread Id        | 0 ... 15  | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id        | 0 ... 31  | 32 ... 63 | Vector
+    // Register Element ----------- ----------- ----------- ----------- Element || Register Element ----------- ----------- Element
+    // Reg 0 [0:7]      |    K0     |    K32    |    K64    |    K96    |  v[0]  || Reg 0 [0:7]      |    K0     |    K32    |  v[0]
+    // Reg 0 [8:15]     |    K1     |    K33    |    K65    |    K97    |  v[1]  || Reg 0 [8:15]     |    K1     |    K33    |  v[1]
+    // Reg 0 [16:23]    |    K2     |    K34    |    K66    |    K98    |  v[2]  || Reg 0 [16:23]    |    K2     |    K34    |  v[2]
+    // Reg 0 [24:31]    |    K3     |    K35    |    K67    |    K99    |  v[3]  || Reg 0 [24:31]    |    K3     |    K35    |  v[3]
+    // Reg 1 [0:7]      |    K4     |    K36    |    K68    |    K100   |  v[4]  || Reg 1 [0:7]      |    K4     |    K36    |  v[4]
+    // Reg 1 [8:15]     |    K5     |    K37    |    K69    |    K101   |  v[5]  || Reg 1 [8:15]     |    K5     |    K37    |  v[5]
+    // Reg 1 [16:23]    |    K6     |    K38    |    K70    |    K102   |  v[6]  || Reg 1 [16:23]    |    K6     |    K38    |  v[6]
+    // Reg 1 [24:31]    |    K7     |    K39    |    K71    |    K103   |  v[7]  || Reg 1 [24:31]    |    K7     |    K39    |  v[7]
+    // Reg 2 [0:7]      |    K8     |    K40    |    K72    |    K104   |  v[8]  || Reg 2 [0:7]      |    K8     |    K40    |  v[8]
+    // Reg 2 [8:15]     |    K9     |    K41    |    K73    |    K105   |  v[9]  || Reg 2 [8:15]     |    K9     |    K41    |  v[9]
+    // Reg 2 [16:23]    |    K10    |    K42    |    K74    |    K106   |  v[10] || Reg 2 [16:23]    |    K10    |    K42    |  v[10]
+    // Reg 2 [24:31]    |    K11    |    K43    |    K75    |    K107   |  v[11] || Reg 2 [24:31]    |    K11    |    K43    |  v[11]
+    // Reg 3 [0:7]      |    K12    |    K44    |    K76    |    K108   |  v[12] || Reg 3 [0:7]      |    K12    |    K44    |  v[12]
+    // Reg 3 [8:15]     |    K13    |    K45    |    K77    |    K109   |  v[13] || Reg 3 [8:15]     |    K13    |    K45    |  v[13]
+    // Reg 3 [16:23]    |    K14    |    K46    |    K78    |    K110   |  v[14] || Reg 3 [16:23]    |    K14    |    K46    |  v[14]
+    // Reg 3 [24:31]    |    K15    |    K47    |    K79    |    K111   |  v[15] || Reg 3 [24:31]    |    K15    |    K47    |  v[15]
+    // Reg 4 [0:7]      |    K16    |    K48    |    K80    |    K112   |  v[16] || Reg 4 [0:7]      |    K16    |    K48    |  v[16]
+    // Reg 4 [8:15]     |    K17    |    K49    |    K81    |    K113   |  v[17] || Reg 4 [8:15]     |    K17    |    K49    |  v[17]
+    // Reg 4 [16:23]    |    K18    |    K50    |    K82    |    K114   |  v[18] || Reg 4 [16:23]    |    K18    |    K50    |  v[18]
+    // Reg 4 [24:31]    |    K19    |    K51    |    K83    |    K115   |  v[19] || Reg 4 [24:31]    |    K19    |    K51    |  v[19]
+    // Reg 5 [0:7]      |    K20    |    K52    |    K84    |    K116   |  v[20] || Reg 5 [0:7]      |    K20    |    K52    |  v[20]
+    // Reg 5 [8:15]     |    K21    |    K53    |    K85    |    K117   |  v[21] || Reg 5 [8:15]     |    K21    |    K53    |  v[21]
+    // Reg 5 [16:23]    |    K22    |    K54    |    K86    |    K118   |  v[22] || Reg 5 [16:23]    |    K22    |    K54    |  v[22]
+    // Reg 5 [24:31]    |    K23    |    K55    |    K87    |    K119   |  v[23] || Reg 5 [24:31]    |    K23    |    K55    |  v[23]
+    // Reg 6 [0:7]      |    K24    |    K56    |    K88    |    K120   |  v[24] || Reg 6 [0:7]      |    K24    |    K56    |  v[24]
+    // Reg 6 [8:15]     |    K25    |    K57    |    K89    |    K121   |  v[25] || Reg 6 [8:15]     |    K25    |    K57    |  v[25]
+    // Reg 6 [16:23]    |    K26    |    K58    |    K90    |    K122   |  v[26] || Reg 6 [16:23]    |    K26    |    K58    |  v[26]
+    // Reg 6 [24:31]    |    K27    |    K59    |    K91    |    K123   |  v[27] || Reg 6 [24:31]    |    K27    |    K59    |  v[27]
+    // Reg 7 [0:7]      |    K28    |    K60    |    K92    |    K124   |  v[28] || Reg 7 [0:7]      |    K28    |    K60    |  v[28]
+    // Reg 7 [8:15]     |    K29    |    K61    |    K93    |    K125   |  v[29] || Reg 7 [8:15]     |    K29    |    K61    |  v[29]
+    // Reg 7 [16:23]    |    K30    |    K62    |    K94    |    K126   |  v[30] || Reg 7 [16:23]    |    K30    |    K62    |  v[30]
+    // Reg 7 [24:31]    |    K31    |    K63    |    K95    |    K127   |  v[31] || Reg 7 [24:31]    |    K31    |    K63    |  v[31]
+    // clang-format on
+
+    // Here we want to load a BLOCK_M x BLOCK_K block of data.
+    static constexpr uint32_t VW = vectorSize(AFragT{});
+    using ARawT                  = typename scalar_type<AFragT>::type;
+    using AScalarFragT           = typename vector_type<ARawT, VW>::type;
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 32 elements.
+    // We need to know where they start, and where the next elements are.
+    auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M,         // Row
+                                       (threadIdx.x / BLOCK_M) * VW); // Col
+    auto stepCoord2D  = std::make_pair(0u, 1u);
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    // BLOCK_M is a stride in A matrix
+    auto startOffset = col_major(startCoord2D, BLOCK_M);
+    auto kOffset     = col_major(stepCoord2D, BLOCK_M);
+
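+    // e.g. for 16x128 (BLOCK_M = 16, VW = 32), lane 17 gets startCoord2D = (1, 32),
+    // so startOffset = 1 + 32 * 16 = 513 and kOffset = 16: it reads input_ptr[513],
+    // input_ptr[529], ... - one element of row M1 from each column K32..K63.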
+    // kOffset == BLOCK_M
+    // This means every BLOCK_M element is loaded into output vector
+    auto fragA = AScalarFragT{};
+#pragma unroll VW
+    for(uint32_t i = 0; i < VW; i++)
+    {
+        fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
+    }
+
+    return fragA;
+}
+
+// Define a load function for input B blocks:
+// Size: (BLOCK_K x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in row_major format
+// This means:
+// - From B we will load K rows of size BLOCK_N to satisfy our input data
+template <typename BType, typename BFragT, uint32_t BLOCK_N, uint32_t BLOCK_K>
+__device__ BFragT load_B_col_major(BType const* input_ptr)
+{
+    // clang-format off
+    // Register Mapping for 128x16:                                                         || Register Mapping for 64x32:
+    // Size             |  BLOCK_N  |  BLOCK_N  |  BLOCK_N  |  BLOCK_N  |        || Size             |  BLOCK_N  |  BLOCK_N  |
+    // N                | 0 ... 15  | 0 ... 15  | 0 ... 15  | 0 ... 15  |        || N                | 0 ... 31  | 0 ... 31  |
+    // Thread Id        | 0 ... 15  | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id        | 0 ... 31  | 32 ... 63 | Vector
+    // Register Element ----------- ----------- ----------- ----------- Element || Register Element ----------- ----------- Element
+    // Reg 0 [0:7]      |    K0     |    K32    |    K64    |    K96    |  v[0]  || Reg 0 [0:7]      |    K0     |    K32    |  v[0]
+    // Reg 0 [8:15]     |    K1     |    K33    |    K65    |    K97    |  v[1]  || Reg 0 [8:15]     |    K1     |    K33    |  v[1]
+    // Reg 0 [16:23]    |    K2     |    K34    |    K66    |    K98    |  v[2]  || Reg 0 [16:23]    |    K2     |    K34    |  v[2]
+    // Reg 0 [24:31]    |    K3     |    K35    |    K67    |    K99    |  v[3]  || Reg 0 [24:31]    |    K3     |    K35    |  v[3]
+    // Reg 1 [0:7]      |    K4     |    K36    |    K68    |    K100   |  v[4]  || Reg 1 [0:7]      |    K4     |    K36    |  v[4]
+    // Reg 1 [8:15]     |    K5     |    K37    |    K69    |    K101   |  v[5]  || Reg 1 [8:15]     |    K5     |    K37    |  v[5]
+    // Reg 1 [16:23]    |    K6     |    K38    |    K70    |    K102   |  v[6]  || Reg 1 [16:23]    |    K6     |    K38    |  v[6]
+    // Reg 1 [24:31]    |    K7     |    K39    |    K71    |    K103   |  v[7]  || Reg 1 [24:31]    |    K7     |    K39    |  v[7]
+    // Reg 2 [0:7]      |    K8     |    K40    |    K72    |    K104   |  v[8]  || Reg 2 [0:7]      |    K8     |    K40    |  v[8]
+    // Reg 2 [8:15]     |    K9     |    K41    |    K73    |    K105   |  v[9]  || Reg 2 [8:15]     |    K9     |    K41    |  v[9]
+    // Reg 2 [16:23]    |    K10    |    K42    |    K74    |    K106   |  v[10] || Reg 2 [16:23]    |    K10    |    K42    |  v[10]
+    // Reg 2 [24:31]    |    K11    |    K43    |    K75    |    K107   |  v[11] || Reg 2 [24:31]    |    K11    |    K43    |  v[11]
+    // Reg 3 [0:7]      |    K12    |    K44    |    K76    |    K108   |  v[12] || Reg 3 [0:7]      |    K12    |    K44    |  v[12]
+    // Reg 3 [8:15]     |    K13    |    K45    |    K77    |    K109   |  v[13] || Reg 3 [8:15]     |    K13    |    K45    |  v[13]
+    // Reg 3 [16:23]    |    K14    |    K46    |    K78    |    K110   |  v[14] || Reg 3 [16:23]    |    K14    |    K46    |  v[14]
+    // Reg 3 [24:31]    |    K15    |    K47    |    K79    |    K111   |  v[15] || Reg 3 [24:31]    |    K15    |    K47    |  v[15]
+    // Reg 4 [0:7]      |    K16    |    K48    |    K80    |    K112   |  v[16] || Reg 4 [0:7]      |    K16    |    K48    |  v[16]
+    // Reg 4 [8:15]     |    K17    |    K49    |    K81    |    K113   |  v[17] || Reg 4 [8:15]     |    K17    |    K49    |  v[17]
+    // Reg 4 [16:23]    |    K18    |    K50    |    K82    |    K114   |  v[18] || Reg 4 [16:23]    |    K18    |    K50    |  v[18]
+    // Reg 4 [24:31]    |    K19    |    K51    |    K83    |    K115   |  v[19] || Reg 4 [24:31]    |    K19    |    K51    |  v[19]
+    // Reg 5 [0:7]      |    K20    |    K52    |    K84    |    K116   |  v[20] || Reg 5 [0:7]      |    K20    |    K52    |  v[20]
+    // Reg 5 [8:15]     |    K21    |    K53    |    K85    |    K117   |  v[21] || Reg 5 [8:15]     |    K21    |    K53    |  v[21]
+    // Reg 5 [16:23]    |    K22    |    K54    |    K86    |    K118   |  v[22] || Reg 5 [16:23]    |    K22    |    K54    |  v[22]
+    // Reg 5 [24:31]    |    K23    |    K55    |    K87    |    K119   |  v[23] || Reg 5 [24:31]    |    K23    |    K55    |  v[23]
+    // Reg 6 [0:7]      |    K24    |    K56    |    K88    |    K120   |  v[24] || Reg 6 [0:7]      |    K24    |    K56    |  v[24]
+    // Reg 6 [8:15]     |    K25    |    K57    |    K89    |    K121   |  v[25] || Reg 6 [8:15]     |    K25    |    K57    |  v[25]
+    // Reg 6 [16:23]    |    K26    |    K58    |    K90    |    K122   |  v[26] || Reg 6 [16:23]    |    K26    |    K58    |  v[26]
+    // Reg 6 [24:31]    |    K27    |    K59    |    K91    |    K123   |  v[27] || Reg 6 [24:31]    |    K27    |    K59    |  v[27]
+    // Reg 7 [0:7]      |    K28    |    K60    |    K92    |    K124   |  v[28] || Reg 7 [0:7]      |    K28    |    K60    |  v[28]
+    // Reg 7 [8:15]     |    K29    |    K61    |    K93    |    K125   |  v[29] || Reg 7 [8:15]     |    K29    |    K61    |  v[29]
+    // Reg 7 [16:23]    |    K30    |    K62    |    K94    |    K126   |  v[30] || Reg 7 [16:23]    |    K30    |    K62    |  v[30]
+    // Reg 7 [24:31]    |    K31    |    K63    |    K95    |    K127   |  v[31] || Reg 7 [24:31]    |    K31    |    K63    |  v[31]
+    // clang-format on
+
+    // Here we want to load a BLOCK_K x BLOCK_N block of data.
+    static constexpr uint32_t VW = vectorSize(BFragT{});
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 32 elements.
+    // We need to know where they start, and where the next elements are.
+    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
+                                       threadIdx.x % BLOCK_N);       // Col
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    auto startOffset = col_major(startCoord2D, BLOCK_K);
+
+    auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
+    return *fragPtr;
+}
+
+// Define a store function for C
+// Size: (BLOCK_M x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in col_major format
+// This means:
+// - To C we will store BLOCK_M rows of size BLOCK_N from our output data
+template <typename CType, typename CFragT, uint32_t BLOCK_M, uint32_t BLOCK_N>
+struct store_C_col_major;
+
+// Here we want to store a 16x16 block of data.
+//
+// Size             |  BLOCK_N  |  BLOCK_N  |  BLOCK_N  |  BLOCK_N  |
+// N                | 0 ... 15  | 0 ... 15  | 0 ... 15  | 0 ... 15  |
+// Thread Id        | 0 ... 15  | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
+// Register Element ----------- ----------- ----------- ----------- Element
+// Reg0             |    M0     |    M4     |    M8     |    M12    |  v[0]
+// Reg1             |    M1     |    M5     |    M9     |    M13    |  v[1]
+// Reg2             |    M2     |    M6     |    M10    |    M14    |  v[2]
+// Reg3             |    M3     |    M7     |    M11    |    M15    |  v[3]
+template <typename CType, typename CFragT>
+struct store_C_col_major<CType, CFragT, 16, 16>
+{
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
+        static constexpr uint32_t Dim = 16;
+
+        // Each thread will load 4 elements.
+        // We need to know where they start, and where the next elements are.
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+
+        // Flatten to 1D col_major offsets.
+        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+        auto startOffset = col_major(startCoord2D, 16);
+
+        auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
+        *fragPtr      = cFrag;
+    }
+};
+
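+// e.g. lane 20 of the 16x16 store above: startCoord2D = ((20 / 16) * 4, 20 % 16)
+// = (4, 4), so it writes rows M4..M7 of column N4 as one 4-element vector at
+// col_major offset 4 + 4 * 16 = 68, matching the Reg0..Reg3 column for threads
+// 16..31 in the table.
+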
+// Here we want to store a 32x32 block of data.
+// Register Mapping:
+
+// Size             |  BLOCK_N  |  BLOCK_N  |
+// N                | 0 ... 31  | 0 ... 31  |
+// Thread Id        | 0 ... 31  | 32 ... 63 | Vector
+// Register Element ----------- ----------- Element
+// Reg0             |    M0     |    M4     |  v[0]
+// Reg1             |    M1     |    M5     |  v[1]
+// Reg2             |    M2     |    M6     |  v[2]
+// Reg3             |    M3     |    M7     |  v[3]
+//                   ___________ ___________
+// Reg4             |    M8     |    M12    |  v[4]
+// Reg5             |    M9     |    M13    |  v[5]
+// Reg6             |    M10    |    M14    |  v[6]
+// Reg7             |    M11    |    M15    |  v[7]
+//                   ___________ ___________
+// Reg8             |    M16    |    M20    |  v[8]
+// Reg9             |    M17    |    M21    |  v[9]
+// Reg10            |    M18    |    M22    |  v[10]
+// Reg11            |    M19    |    M23    |  v[11]
+//                   ___________ ___________
+// Reg12            |    M24    |    M28    |  v[12]
+// Reg13            |    M25    |    M29    |  v[13]
+// Reg14            |    M26    |    M30    |  v[14]
+// Reg15            |    M27    |    M31    |  v[15]
+
+template <typename CType, typename CFragT>
+struct store_C_col_major<CType, CFragT, 32, 32>
+{
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t WAVE_SIZE      = 64;
+        static constexpr uint32_t VW             = 4;
+        static constexpr uint32_t Dim            = 32;
+        static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
+
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+
+        // Major step between 'chunks'
+        auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
+
+        // Flatten to 1D col_major offsets.
+        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+        auto startOffset  = col_major(startCoord2D, 32);
+        auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
+
+        // we can vector store 4 contiguous elements at a time.
+        using CRawT        = typename scalar_type<CFragT>::type;
+        using CScalarFragT = typename vector_type<CRawT, VW>::type;
+        union
+        {
+            CFragT frag;
+            CScalarFragT chunks[vectorSize(CFragT{}) / VW];
+        } fragC{cFrag}; // Initialize with input fragment
+
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset)) = fragC.chunks[0];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + kMajorOffset)) = fragC.chunks[1];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 2 * kMajorOffset)) =
+            fragC.chunks[2];
+        *(reinterpret_cast<CScalarFragT*>(output + startOffset + 3 * kMajorOffset)) =
+            fragC.chunks[3];
+    }
+};
+
+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AccType,
+          uint32_t BLOCK_M,
+          uint32_t BLOCK_N,
+          uint32_t BLOCK_K,
+          MFMA_F8F6F4 mfma>
+__global__ void matmul(const AType* a, const BType* b, CType* c)
+{
+    constexpr int WAVE_SIZE = 64;
+    assert(threadIdx.x < WAVE_SIZE);
+    assert(blockDim.x == WAVE_SIZE && blockDim.y == 1 && blockDim.z == 1); // one wave per block
+
+    using AFragT     = typename vector_type<AType, BLOCK_M * BLOCK_K / WAVE_SIZE>::type;
+    using BFragT     = typename vector_type<BType, BLOCK_N * BLOCK_K / WAVE_SIZE>::type;
+    using CFragT     = typename vector_type<CType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+    using AccumFragT = vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>;
+    using RawAccumFragT = typename vector_type<AccType, BLOCK_M * BLOCK_N / WAVE_SIZE>::type;
+
+    // Create frags
+    auto fragA   = AFragT{};
+    auto fragB   = BFragT{};
+    auto fragC   = CFragT{};
+    auto fragAcc = AccumFragT{0};
+
+    // Load the inputs.
+    // A = col major, BLOCK_M x BLOCK_K
+    fragA = load_A_col_major<AType, AFragT, BLOCK_M, BLOCK_K>(a);
+    // B = col major, BLOCK_K x BLOCK_N
+    fragB = load_B_col_major<BType, BFragT, BLOCK_N, BLOCK_K>(b);
+
+    // Matrix multiply-accumulate using MFMA units
+    // Accumulation intermediate = BLOCK_M x BLOCK_N
+    mfma_type_selector<mfma, AFragT, BFragT, AccumFragT>{}(fragA, fragB, fragAcc);
+
+    for(int i = 0; i < vectorSize(fragC); ++i)
+    {
+        fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
+    }
+
+    auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
+    storeC(c, fragC);
+}
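+// Size check for the kernel above (one wave of 64 lanes computes the whole
+// tile): for 16x16x128, each lane carries 16*128/64 = 32 A elements,
+// 128*16/64 = 32 B elements and 16*16/64 = 4 accumulators; for 32x32x64 it is
+// 32*64/64 = 32 A/B elements and 32*32/64 = 16 accumulators - matching the
+// frag types derived inside the kernel.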
+/**
+ * @brief Structure to hold dimension parameters for GEMM tensors.
+ *
+ * M       Number of rows in matrix A and matrix C.
+ * N       Number of columns in matrix B and matrix C.
+ * K       Number of columns in matrix A and number of rows in matrix B.
+ * StrideA Stride (leading dimension) of matrix A.
+ * StrideB Stride (leading dimension) of matrix B.
+ * StrideC Stride (leading dimension) of matrix C.
+ */
+struct GemmParams
+{
+    ck::index_t M = 16;
+    ck::index_t N = 16;
+    ck::index_t K = 128;
+
+    ck::index_t StrideA = -1;
+    ck::index_t StrideB = -1;
+    ck::index_t StrideC = -1;
+};
+
+namespace mfma_test {
+template <typename GemmInstance,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+void RunHostGEMM(const Tensor<ADataType>& A,
+                 const Tensor<BDataType>& B,
+                 Tensor<CDataType>& C,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 CElementwiseOperation c_element_op)
+{
+    auto ref_gemm     = GemmInstance{};
+    auto ref_invoker  = ref_gemm.MakeInvoker();
+    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
+
+    ref_invoker.Run(ref_argument);
+}
+
+template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
+bool RunDeviceGEMM(KernelType kernel,
+                   const Tensor<ADataType>& A,
+                   const Tensor<BDataType>& B,
+                   Tensor<CDataType>& C)
+{
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
+    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
+
+    a_m_k_device_buf.ToDevice(A.mData.data());
+    b_n_k_device_buf.ToDevice(B.mData.data());
+    kernel<<<1, 64>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
+                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
+    c_m_n_device_buf.FromDevice(C.mData.data());
+
+    return true;
+}
+
+template <typename DeviceMFMA,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AccDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          index_t BLOCK_M,
+          index_t BLOCK_N,
+          index_t BLOCK_K>
+struct TestMFMA
+{
+    auto PrepareGemmTensors(const GemmParams& params, index_t init)
+    {
+        auto f_host_tensor_descriptor =
+            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                if(std::is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({stride, 1}));
+                }
+                else
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({1, stride}));
+                }
+            };
+
+        Tensor<ADataType> a_m_k(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor<BDataType> b_n_k(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor<CDataType> c_m_n_host_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+        Tensor<CDataType> c_m_n_device_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        switch(init)
+        {
+        case 0:
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.015625f});
+            // NOTE: not all numbers are representable in FP8, BF8, etc.
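+            // e.g. e4m3 represents 16 and 18 but not 17 (the mantissa step is 2
+            // in [16, 32)), so sequential integer data rounds; init 0 pairs an
+            // exactly representable constant 2^-6 on A with that worst case on B.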
+            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
+            break;
+        case 1:
+            // results in C = {K}
+            a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
+            b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
+            break;
+        case 2:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-5, 5});
+            b_n_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-5, 5});
+            break;
+        case 3:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
+            b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
+            break;
+        default:
+            // all initial values are representable in FP8, BF8
+            a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6});
+            b_n_k.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6});
+
+            break;
+        }
+
+        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
+    }
+
+    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
+    {
+        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
+                  << ", CLayout = " << CLayout{}.name << std::endl;
+
+        // Arrange
+        GemmParams params;
+        params.M = BLOCK_M;
+        params.N = BLOCK_N;
+        params.K = BLOCK_K;
+
+        auto f_get_default_stride = [](std::size_t row,
+                                       std::size_t col,
+                                       ck::index_t stride,
+                                       auto layout) {
+            if(stride == -1)
+            {
+                // give a chance if stride is -1, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
+                {
+                    return static_cast<std::size_t>(col);
+                }
+                else
+                {
+                    return static_cast<std::size_t>(row);
+                }
+            }
+            else
+                return static_cast<std::size_t>(stride);
+        };
+
+        params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
+        params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
+        params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
+
+        auto host_tensors = PrepareGemmTensors(params, init);
+
+        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
+        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
+        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
+        Tensor<CDataType>& c_device = std::get<3>(host_tensors);
+
+        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+        auto a_element_op = PassThrough{};
+        auto b_element_op = PassThrough{};
+        auto c_element_op = PassThrough{};
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CDataType,
+                                                                                AccDataType,
+                                                                                PassThrough,
+                                                                                PassThrough,
+                                                                                PassThrough>;
+
+        RunHostGEMM<ReferenceGemmInstance>(a, b, c_host, a_element_op, b_element_op, c_element_op);
+
+        RunDeviceGEMM(mfma_kernel, a, b, c_device);
+
+        bool res = false;
+        if constexpr(std::is_same<CDataType, float>::value ||
+                     std::is_same<CDataType, half_t>::value)
+        {
+            res = ck::utils::check_err(c_device.mData, c_host.mData);
+            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+        }
+        else
+        {
+            std::cout << "UNSUPPORTED CDataType" << std::endl;
+        }
+
+        return res;
+    }
+};
+
+} // namespace mfma_test
+} // namespace ck
diff --git a/test/smfmac_op/smfmac_op_xdl.cpp b/test/smfmac_op/smfmac_op_xdl.cpp
index 0e17401b0d..3cf2fb73b9 100644
--- a/test/smfmac_op/smfmac_op_xdl.cpp
+++ b/test/smfmac_op/smfmac_op_xdl.cpp
@@ -40,7 +40,7 @@ class TestSmfmac : public ::testing::Test
     void Run()
     {
         bool pass = true;
-        if(ck::get_device_name() == "gfx942")
+        if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
         {
             constexpr auto matmul_default = ck::smfmac_op_util::matmul