diff --git a/Dockerfile b/Dockerfile index 2873a8500b..17800d92d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -85,6 +85,11 @@ RUN pip install --upgrade cmake==3.27.5 && \ gunzip /usr/local/bin/ninja.gz && \ chmod a+x /usr/local/bin/ninja && \ git clone https://github.com/nico/ninjatracing.git && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ #Install latest cppcheck git clone https://github.com/danmar/cppcheck.git && \ cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ diff --git a/Jenkinsfile b/Jenkinsfile index a40bd97f3a..29aec8e709 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -288,7 +288,7 @@ def cmake_build(Map conf=[:]){ if(!setup_args.contains("NO_CK_BUILD")){ if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ echo "running ninja build trace" - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ") + setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") } else{ @@ -316,7 +316,10 @@ def cmake_build(Map conf=[:]){ if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . 
clang_build.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" archiveArtifacts "ck_build_trace.json" + archiveArtifacts "clang_build_analysis.log" // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ sh "ninja test" diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md index ed3dff0f1e..834fd62c8f 100644 --- a/client_example/11_grouped_conv_bwd_weight/README.md +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -36,10 +36,10 @@ Table of supported cases by instance factory with XDL instruction: | |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| |-------|---|---|---| -|bf16|2D, 3D|✗|✗| +|bf16|2D, 3D|2D, 3D|✗| |bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| -|fp16 |2D, 3D|✗|1D, 2D, 3D| -|fp32 |2D, 3D|✗|1D, 2D, 3D| +|fp16 |2D, 3D|2D, 3D|1D, 2D, 3D| +|fp32 |2D, 3D|2D, 3D|1D, 2D, 3D| Table of supported cases by instance factory with WMMA instruction: diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index cf77991a64..d0d30d669a 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -144,7 +144,7 @@ function(clang_tidy_check TARGET) # COMMAND ${CLANG_TIDY_COMMAND} $, > foreach(SOURCE ${SOURCES}) if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) - string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + string(MD5 tidy_file "${SOURCE}") set(tidy_target tidy-target-${TARGET}-${tidy_file}) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index f5c7013698..80f7e95d30 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -261,7 +261,8 @@ bool run_gemm(const ProblemType& problem_size, const 
ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument)) + if(!gemm.IsSupportedArgument(argument) || + (ck::get_device_name() != "gfx942" && ck::get_device_name() != "gfx950")) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index a8101587e8..7b72461dd9 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument)) + if(!gemm.IsSupportedArgument(argument) || + (ck::get_device_name() != "gfx942" && ck::get_device_name() != "gfx950")) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 95fd8bace8..38b42fefc4 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -3,14 +3,14 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +# add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) list(APPEND gpu_list gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) + # 
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) set(target 1) endif() diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 93770684df..9e95c3e007 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,5 +1,10 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) +add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) + +add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index c0a0972db6..713902588d 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -2,16 +2,24 @@ ## example_gemm_mx_fp8 +Custom verification parameters: ```bash # arg1: verification (0=no, 1=CPU) -# arg2: initialization (0=no init, 1=integer value, 2=decimal value) +# arg2: initialization (0=constant values, 1=integer values, 2=decimal values) # arg3: time kernel (0=no, 1=yes) # arg4: verbosity (0=no info, 1=verbose info) -# arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC -./bin/example_gemm_mx_fp8 1 1 0 1 +# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC +# arg11: KBatch +./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 ``` +Custom tensor shapes: ```bash -# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 -./bin/example_gemm_mx_fp8 
+./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +``` + +Default invocation: +```bash +# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 +./bin/example_gemm_mx_fp8_fp8_scale ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 30f03cb53b..9a05954c73 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -9,20 +9,17 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" - #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" - #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" -using ScaleDataType = ck::e8m0_bexp_t; - template using S = ck::Sequence; @@ -31,16 +28,19 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ck::type_convert; + struct ExecutionConfig final { int do_verification = 1; // (0=no, 1=CPU) - int init_method = 2; // (0=no init, 1=integer value, 2=decimal value) + int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) bool time_kernel = false; // (0=no, 1=yes) int verbosity = 0; // (0=no info, 1=verbose info) }; -struct ProblemSize final +struct ProblemSizeSplitK final { + ck::index_t M = 3840; ck::index_t N = 4096; 
ck::index_t K = 4096; @@ -48,9 +48,14 @@ struct ProblemSize final ck::index_t StrideA = -1; ck::index_t StrideB = -1; ck::index_t StrideC = -1; + + ck::index_t KBatch = 1; }; -bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) { if(argc == 1) { @@ -63,7 +68,7 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution config.time_kernel = std::stoi(argv[3]); config.verbosity = std::stoi(argv[4]); } - else if(argc == 11) + else if(argc >= 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); @@ -77,15 +82,21 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution problem_size.StrideA = std::stoi(argv[8]); problem_size.StrideB = std::stoi(argv[9]); problem_size.StrideC = std::stoi(argv[10]); + + if(argc >= 12) + { + problem_size.KBatch = std::stoi(argv[11]); + } } else { std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl + << "arg11: KBatch" << std::endl; return false; } @@ -99,56 +110,70 @@ template -bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - using ELayout = CLayout; - using DsLayout = ck::Tuple<>; - using DsDataType = ck::Tuple<>; - using AElementOp = PassThrough; - using BElementOp = PassThrough; - using CDEElementOp = CElementWiseOp; - static 
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; -#if 1 - // XXX: These parameters should not exist in MX-native GEMM kernel - static constexpr ck::index_t Scale_Block_M = 128; - static constexpr ck::index_t Scale_Block_N = 128; -#endif - static constexpr ck::index_t Scale_Block_K = MXVectorSize; + static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA - // instructions. - // - // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized - // scaled type convert functions. - // - // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to - // ScaleBlockK (aka MXVectorSize). - // Additionally, the following is also expected: - // static_assert(ScaleBlockM % MPerBlock == 0); - // static_assert(ScaleBlockN % NPerBlock == 0); - // In MX-native GEMM kernel these requirements should be relaxed. 
- // - // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 - // MfmaSelector::selected_mfma.k_per_blk = - // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 - // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float - - // clang-format off - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 - // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| - // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | - // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | - // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 
256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; - // clang-format on + static constexpr ck::index_t KPerBlock = 64; + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + MXVectorSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // 
BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; auto M = problem_size.M; auto N = problem_size.N; @@ -156,6 +181,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) auto StrideA = problem_size.StrideA; auto StrideB = problem_size.StrideB; auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; auto f_host_tensor_descriptor = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { @@ -191,21 +217,26 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - if(K % Scale_Block_K != 0) + if(K % ScaleBlockSize != 0) { - throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)"); + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; - auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); - auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + // Hardcode scale layouts as per pipeline assumptions + // TODO: Allow user to specify scale layouts + using AScaleLayout = Row; + using BScaleLayout = Col; - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k_scale( - f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A - Tensor b_k_n_scale( - f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + 
Tensor a_m_k_scale(f_host_tensor_descriptor( + M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A + Tensor b_k_n_scale(f_host_tensor_descriptor( + K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B Tensor c_m_n_host_result( f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification @@ -223,28 +254,49 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) switch(config.init_method) { - case 0: - if(config.verbosity > 0) - { - std::cout << "NOTE: No input data initialization." << std::endl; - } - break; - case 1: - case 2: + case 0: // Initializations for development and debugging ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); if(config.verbosity > 0) { std::cout << "Init A = {1}" << std::endl; - std::cout << "Init A scale = {0.5}" << std::endl; - std::cout << "Init B = {1}" << std::endl; - std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {0.5}" << std::endl; + std::cout << "Init B scale = {1.0}" << std::endl; std::cout << "Expect C = {K}" << std::endl; } break; + case 1: + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + + if constexpr(ck::is_same_v) + { + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + } + else + { + ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); + 
ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); + } + + break; + + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + + b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + break; + default: if(config.verbosity > 0) { @@ -269,31 +321,31 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) if(config.verbosity > 0) std::cout << "Done." << std::endl; - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; - constexpr ck::index_t NumDTensor = DsDataType::Size(); - - // do GEMM + // run GEMM auto device_op = DeviceOpInstance{}; auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - c_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{}, - StrideC, - a_scale_device_buf.GetDeviceBuffer(), - b_scale_device_buf.GetDeviceBuffer(), - a_element_op, - b_element_op, - cde_element_op); + auto argument = + device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); if(!device_op.IsSupportedArgument(argument)) { @@ -303,7 +355,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) } if(config.verbosity > 0) - std::cout << "Computing GEMM on device..." 
<< std::endl; + { + std::cout << "Computing GEMM on device..." << std::endl << std::endl; + } + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); @@ -321,7 +376,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) BDataType, CDataType, AccDataType, - float, + XDataType, PassThrough, PassThrough, PassThrough, @@ -347,12 +402,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) std::cout << "Comparing results..." << std::endl; } - if(config.init_method == 1) + if(config.init_method == 0) { - res_verified = - res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; - std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) - << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + auto expected = static_cast(K); + auto computed = type_convert(c_m_n_device_result(1, 12)); + + res_verified = res_verified && std::abs(expected - computed) <= 0.0f; + std::cout << "\nExpected vs Computed: " << expected << " vs " << computed + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl + << std::endl; } res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, @@ -360,7 +418,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) "Error: Incorrect results!"); if(config.verbosity > 0 && res_verified) - std::cout << "Done." << std::endl; + std::cout << "Verification Successful!" 
<< std::endl; } else { @@ -370,17 +428,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t flop = std::size_t(2) * M * N * K + + std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; + << " GB/s, " << device_op.GetTypeString() << std::endl; } return res_verified; @@ -393,13 +452,15 @@ template bool run_mx_gemm_example(int argc, char* argv[]) { - ProblemSize problem_size; + ProblemSizeSplitK problem_size; ExecutionConfig config; return parse_cmd_args(argc, argv, problem_size, config) && @@ -410,7 +471,9 @@ bool run_mx_gemm_example(int argc, char* argv[]) ALayout, BLayout, CLayout, - CElementWiseOp, + AElementOp, + BElementOp, + CElementOp, AccDataType, CShuffleDataType, MXVectorSize>(problem_size, config); diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp similarity index 71% rename from example/67_gemm_microscaling/gemm_mx_fp8.cpp rename to example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp index d2e21698ec..393f4a2ea7 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp @@ -5,23 +5,22 @@ using ADataType = ck::f8_t; using BDataType = ck::f8_t; -#if 1 -// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type -using XDataType = float; -#else + using XDataType = ck::e8m0_bexp_t; -#endif + +using CDataType = 
ck::half_t; using AccDataType = float; -using CShuffleDataType = float; -using CDataType = float; +using CShuffleDataType = CDataType; using ALayout = Row; using BLayout = Col; using CLayout = Row; +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix -constexpr ck::index_t mx_vector_size = 128; // scaling block size +constexpr ck::index_t mx_vector_size = 32; // scaling block size int main(int argc, char* argv[]) { @@ -32,6 +31,8 @@ int main(int argc, char* argv[]) ALayout, BLayout, CLayout, + AElementOp, + BElementOp, CElementOp, AccDataType, CShuffleDataType, diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp new file mode 100644 index 0000000000..dd654a8f69 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::half_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 32; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp new file mode 100644 index 0000000000..c42d9783be --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::f8_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t mx_vector_size = 32; // scaling block size + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0 + : -1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index acd2ea6179..0bfba89c92 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -113,12 +113,15 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") + if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_fp8_pk_i4") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + elseif(FILE_NAME MATCHES "_fp8_pk_i4") #only build these examples for gfx942 and gfx950 + message("trimming targets for ${FILE_NAME}") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 677ccb5ee3..6326a97f8e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << 
fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} ); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 4ff7ede765..e5d11c6dc9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -492,7 +492,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm continue if hdim == 192 and tile.F_bn1 == 128: # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']): + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index b1f9e30178..c6d1a01792 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); 
return hipPeekAtLastError() == hipSuccess; }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} ); }} @@ -439,8 +439,13 @@ class FmhaFwdSplitKVCombinePipeline: pn = pad_name() n = f'{self.tag}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' return n class FmhaFwdSplitKVApiPool: diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 700b007fad..0238a125dc 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -564,9 +564,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0), + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100644 new mode 100755 diff --git 
a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh index e69de29bb2..d7e5d4640a 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh index e69de29bb2..466f6bb4e1 100755 --- a/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh +++ b/example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh @@ -0,0 +1,14 @@ +#!/bin/sh +EXE="$(find . 
-name tile_example_gemm_basic -type f | head -n 1)" +VALID=1 + + +for b_matrix_layout in "C"; do + for m in "64" "512" "1024" "2048"; do + for n in "512" "1024" "2048"; do + for k in "64" "512" "1024" "2048"; do + $EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID + done + done + done +done \ No newline at end of file diff --git a/example/ck_tile/03_gemm/script/run_full_test.sh b/example/ck_tile/03_gemm/script/run_full_test.sh index 2448acbad2..12ea6f0bf8 100755 --- a/example/ck_tile/03_gemm/script/run_full_test.sh +++ b/example/ck_tile/03_gemm/script/run_full_test.sh @@ -32,6 +32,9 @@ function print_log_header(){ } # run verification tests +for dtype in fp16 bf16 fp8 bf8; do + example/ck_tile/03_gemm/script/benchmark_basic_$dtype.sh +done example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh # run performance benchmarks diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 48c150009e..25598282e3 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -41,6 +41,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using YDataType = DataType; using GammaDataType = DataType; using InvRmsDataType = ck_tile::null_type; + using UnquantYDataType = ck_tile::null_type; using SmoothScaleDataType = ck_tile::null_type; using YScaleDataType = ck_tile::null_type; @@ -55,6 +56,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor invRms_host_ref({m}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {stride, 1}); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); @@ -76,6 +79,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits; // fuse quant @@ -85,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ComputeDataType, YDataType, InvRmsDataType, + 
UnquantYDataType, SmoothScaleDataType, YScaleDataType, Shape, @@ -108,6 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser) nullptr, nullptr, nullptr, + nullptr, epsilon, m, n, @@ -135,8 +141,9 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, ComputeDataType, YDataType, - InvRmsDataType>( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + InvRmsDataType, + UnquantYDataType>( + x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_host_ref, epsilon); y_buf.FromDevice(y_host_dev.data()); diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index dadb2268b2..39d42e5ff1 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -54,6 +54,7 @@ template @@ -70,6 +72,7 @@ struct rmsnorm2d_fwd_traits_ using YDataType = ck_tile::remove_cvref_t; using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + using UnquantYDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); @@ -120,9 +123,10 @@ struct rmsnorm2d_fwd_traits_ using Shape = ck_tile::Generic2dBlockShape; - static constexpr bool kPadN = kPadN_; - static constexpr bool kSaveInvRms = kSaveInvRms_; - static constexpr bool kTwoPass = kTwoPass_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; + static constexpr bool kTwoPass = kTwoPass_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; }; @@ -131,6 +135,7 @@ template @@ -145,6 +151,7 @@ using traits_ = rmsnorm2d_fwd_traits_; @@ -180,11 +188,13 @@ float rmsnorm2d_fwd_(const S& s, A a) using YDataType = typename Traits_::YDataType; using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; using YScaleDataType 
= typename Traits_::YScaleDataType; + using UnquantYDataType = typename Traits_::UnquantYDataType; using ComputeDataType = typename RmsnormTypeConfig::ComputeDataType; using PipelineTraits = ck_tile::Rmsnorm2dFwdTraits(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; @@ -195,6 +205,7 @@ float rmsnorm2d_fwd_(const S& s, A a) typename RmsnormTypeConfig::ComputeDataType, typename RmsnormTypeConfig::YDataType, typename RmsnormTypeConfig::InvRmsDataType, + typename RmsnormTypeConfig::UnquantYDataType, typename RmsnormTypeConfig::SmoothScaleDataType, typename RmsnormTypeConfig::YScaleDataType, typename Traits_::Shape, @@ -213,7 +224,16 @@ float rmsnorm2d_fwd_(const S& s, A a) using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; - using Epilogue = std::conditional_t; + using Default2DAndDynamicQuantEpilogueProblem = ck_tile::Default2DAndDynamicQuantEpilogueProblem< + ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, UnquantYDataType, typename Traits_::Shape, + ck_tile::Default2DAndDynamicQuantEpilogueTraits>; + using Default2DAndDynamicQuantEpilogue = ck_tile::Default2DAndDynamicQuantEpilogue; + + using Epilogue = std::conditional_t, + Default2DEpilogue>; using Kernel = ck_tile::Rmsnorm2dFwd; @@ -355,6 +375,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_YDataType : str F_SmoothScaleDataType : str F_YScaleDataType : str + F_UnquantYDataType : str F_Repeat_M : int F_Repeat_N : int F_ThreadPerBlock_M : int @@ -362,14 +383,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_Vector_N : int F_kPadN : bool F_kSaveInvRms : bool + F_kSaveUnquant: bool F_kTwoPass : bool F_kFusedAdd : int F_kFusedQuant : int @property def trait_name(self) ->str: - t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, 
{BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}' + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}' t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ @@ -390,6 +412,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, F_N : str F_add : int F_sweep : int + F_saveunquant : bool instance_list : List[Any] # List[h_traits] @property @@ -401,6 +424,8 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] if self.F_sweep != 0: nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + if self.F_saveunquant: + nnn = nnn + '_saveunquant' return nnn @property @@ -451,11 +476,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, if ins.F_kFusedQuant == 0: _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) elif ins.F_kFusedQuant == 1: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) elif ins.F_kFusedQuant == 2: - _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( - f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _sweep_cond = 
't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant)) _cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd, f_sweep_cond = _sweep_cond) @@ -489,67 +514,72 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant fused_add_list = [0, 1] fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant + bool_list = [False, True] - # rm rn tm tn vn pd mv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, 
False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, 
True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + # rm rn tm tn vn pd mv unquant 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)], + '640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 
128, 1, True, False, False, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, 
False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N - for dtype, scale_type, fused_add, fused_quant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list): + for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') scale_sm, scale_y = scale_type.split(',') if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2: continue # skip non dynamic quant case if (fused_quant == 1 or fused_quant == 2) and hs_key == 
'big': continue + if (fused_quant == 0 and save_unquant == True): + continue # save_unquant should always be false when there is no quant enabled current_hs = list() for chs_ in hs: h_ = copy.copy(chs_) # copy the base instance out @@ -557,12 +587,14 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_.F_YDataType = prec_o h_.F_SmoothScaleDataType = scale_sm h_.F_YScaleDataType = scale_y + h_.F_UnquantYDataType = prec_i h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant + h_.F_kSaveUnquant = save_unquant current_hs.append(h_) # + "\n" #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ current_n_str = 'big' if hs_key == 'big' else current_n - total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) + total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs)) return total_blob def list_blobs(self) -> None: diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index cdee6dfb80..d5be4384ab 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -38,6 +38,7 @@ auto create_args(int argc, char* argv[]) .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") + .insert("save_unquant", "0", "save result before quant") .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec_i", "fp16", "input precision") @@ -61,7 +62,8 @@ template + bool SaveRms, + bool SaveUnquant> bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -113,6 +115,14 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } + if((fused_quant == 0) && SaveUnquant) + { + std::cout + << "save_unquant should be 0 if quant output is not enabled because it is meaningless. 
" + << "Output Y is what wanted." << std::endl; + return false; + } + using TypeConfig = RmsnormTypeConfig; @@ -124,6 +134,8 @@ bool run(const ck_tile::ArgParser& arg_parser) using InvRmsDataType = std::conditional_t; + using UnquantYDataType = + std::conditional_t; using ComputeDataType = typename TypeConfig::ComputeDataType; @@ -143,6 +155,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor invRms_host_ref({m}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor unquant_y_host_dev({m, n}, {y_stride, 1}); + ck_tile::HostTensor unquant_y_null({1}); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); @@ -155,6 +171,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem unquant_y_buf(unquant_y_host_dev.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); gamma_buf.ToDevice(gamma_host.data()); @@ -179,7 +196,8 @@ bool run(const ck_tile::ArgParser& arg_parser) << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; - rmsnorm2d_fwd_traits traits{prec_i, prec_o, prec_sm, prec_sy, SaveRms, fused_add, fused_quant}; + rmsnorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant}; rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, @@ -189,6 +207,7 @@ bool run(const ck_tile::ArgParser& arg_parser) fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, fused_quant != 0 ? 
y_scale_buf.GetDeviceBuffer() : nullptr, nullptr, // p_invRms, unsupported yet + SaveUnquant ? unquant_y_buf.GetDeviceBuffer() : nullptr, epsilon, m, n, @@ -203,6 +222,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0; + num_byte += SaveUnquant ? sizeof(UnquantYDataType) * m * n : 0; num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0; num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0; num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0; @@ -262,21 +282,57 @@ bool run(const ck_tile::ArgParser& arg_parser) } }; - ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon, dquant_functor); + auto default_and_dquant_functor = [&](int m_, auto& o_unquant_, auto& o_, auto& acc_) { + const int N = acc_.mDesc.get_lengths()[1]; + for(int n_ = 0; n_ < N; ++n_) + { + o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + } + + dquant_functor(m_, o_, acc_); + }; + + if constexpr(SaveUnquant) + { + ck_tile::reference_rmsnorm2d_fwd(x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_host_ref, + epsilon, + default_and_dquant_functor); + } + else + { + ck_tile::reference_rmsnorm2d_fwd(x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_host_ref, + epsilon, + dquant_functor); + } } else { + assert(SaveUnquant == false); ck_tile::reference_rmsnorm2d_fwd( - x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + InvRmsDataType, + ck_tile::null_type>( + x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); } y_buf.FromDevice(y_host_dev.data()); @@ -293,6 +349,15 @@ bool run(const ck_tile::ArgParser& arg_parser) pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol); + if constexpr(SaveUnquant) + { + pass &= 
ck_tile::check_err(unquant_y_host_dev, + unquant_y_host_ref, + std::string("\n OUT ERROR: Incorrect unquant results!"), + rtol, + atol); + } + if(fused_add == 1) { pass &= ck_tile::check_err(y_residual_host_dev, @@ -331,6 +396,23 @@ bool run(const ck_tile::ArgParser& arg_parser) rtol, atol); } + + if constexpr(SaveUnquant) + { + std::vector unquant_y_host_dev_row( + unquant_y_host_dev.begin() + i_r * y_stride, + unquant_y_host_dev.begin() + i_r * y_stride + n); + std::vector unquant_y_host_ref_row( + unquant_y_host_ref.begin() + i_r * y_stride, + unquant_y_host_ref.begin() + i_r * y_stride + n); + pass &= + ck_tile::check_err(unquant_y_host_dev_row, + unquant_y_host_ref_row, + std::string("\nOUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect unquant y results!"), + rtol, + atol); + } } } @@ -350,6 +432,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return pass; } +bool is_quant_data_type(const std::string& prec) { return (prec == "int8") || (prec == "fp8"); } + int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -373,48 +457,79 @@ int main(int argc, char* argv[]) prec_sy = "fp32"; } - int save_rms = arg_parser.get_int("save_rms"); + int save_rms = arg_parser.get_int("save_rms"); + int fused_quant = arg_parser.get_int("fquant"); + int save_unquant = + arg_parser.get_int("save_unquant") && is_quant_data_type(prec_o) && (fused_quant != 0); if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 
0 + : -2; } else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && !save_rms) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } // dynamic quant case, only in inference else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; } else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && - !save_rms) + !save_rms && !save_unquant) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 + : -2; + } + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_rms && save_unquant) + { + return run(arg_parser) ? 
0 : -2; } return -3; diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index 566b94442d..bb4a2f5ef4 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -21,6 +21,7 @@ struct RmsnormTypeConfig float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // This part comes from the Codegen +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile 
= 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; @@ -36,61 +62,232 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; - using CodegenGemmShape = + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
- using Kernel = ck_tile::BatchedGemmKernel; + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - auto kargs = Kernel::MakeKernelArgs(args); + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + float ave_time{0}; - if(!Kernel::IsSupportedArgument(kargs)) + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for compv3 pipeline! 
Expected Full, Odd or Even, but got " + << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } - if(s.log_level_ > 0) + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else { - std::cout << "Launching kernel with args: " << 
Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + std::ostringstream err; + err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " + "got " + << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 7b7e22160a..0999c7ad3b 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -9,6 +9,30 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + template struct BatchedGemmTypeConfig; @@ -32,19 +56,19 @@ using CDataType = Types::CDataType; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "256", "m dimension") - .insert("n", "128", "n dimension") - .insert("k", "128", "k dimension") + arg_parser.insert("m", "512", "m dimension") + .insert("n", "1024", "n dimension") + .insert("k", "2048", "k dimension") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("batch_stride_a", "32768", "Batch A 
stride") - .insert("batch_stride_b", "16384", "Batch B stride") - .insert("batch_stride_c", "32768", "Batch C stride") - .insert("batch_count", "16", "Batch count") + .insert("batch_stride_a", "1048576", "Batch A stride") + .insert("batch_stride_b", "2097152", "Batch B stride") + .insert("batch_stride_c", "524288", "Batch C stride") + .insert("batch_count", "8", "Batch count") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 1105304e3e..16a31e519a 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -185,7 +185,6 @@ int run_batched_gemm_example_with_layouts(int argc, kbatch, n_warmup, n_repeat); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 03d5818179..2a9903362d 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,85 +16,9 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -namespace { - -struct GroupedGemmKernelParam -{ - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - - static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 128; - static const ck_tile::index_t N_Tile = 128; - static const ck_tile::index_t K_Tile = 32; - - static const ck_tile::index_t M_Warp = 2; - static const ck_tile::index_t N_Warp = 2; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const 
ck_tile::index_t K_Warp_Tile = 8; -}; - -using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - -using TilePartitioner = ck_tile::GemmTile1DPartitioner; - -template -using CodegenGemmTraits = ck_tile::TileGemmTraits; - -template -using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem>; - -template -using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1>; - -template -using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GroupedGemmKernelParam::M_Warp, - GroupedGemmKernelParam::N_Warp, - GroupedGemmKernelParam::M_Warp_Tile, - GroupedGemmKernelParam::N_Warp_Tile, - GroupedGemmKernelParam::K_Warp_Tile, - CodegenPipelineProblem::TransposeC>>; - -template -using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; -}; // namespace - std::size_t get_workspace_size(const std::vector& gemm_descs) { - return ::Kernel::GetWorkSpaceSize(gemm_descs); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template @@ -102,37 +26,265 @@ float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* p_workspace_) { - using GroupedGemmKernel = ::Kernel; +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; - auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs); + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; - const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs); - constexpr dim3 blocks = GroupedGemmKernel::BlockSize(); + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; - ck_tile::hip_check_error(hipMemcpyWithStream( - p_workspace_, - arguments.data(), - arguments.size() 
* sizeof(typename GroupedGemmKernel::GemmTransKernelArg), - hipMemcpyHostToDevice, - s.stream_id_)); + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; - if(s.log_level_ > 0) + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + 
+ const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; +#if(CK_TILE_PIPELINE_DEFAULT == 
CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got " + << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + 
if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " + << "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - float ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - GroupedGemmKernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); return ave_time; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 14d450034d..4fec329c2f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,6 +9,30 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE 
ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + template struct GemmTypeConfig; @@ -29,7 +53,7 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; +using grouped_gemm_kargs = ck_tile::GemmHostArgs; auto create_args(int argc, char* argv[]) { @@ -46,7 +70,7 @@ auto create_args(int argc, char* argv[]) .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") - .insert("group_count", "16", "group count."); + .insert("group_count", "8", "group count."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 080ea818c9..f068510d26 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -101,8 +101,8 @@ int run_grouped_gemm_example_with_layouts(int argc, for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(128 + 128 * i); - Ks.push_back(128 + 64 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(256 + 64 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -169,7 +169,10 @@ int run_grouped_gemm_example_with_layouts(int argc, const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = 
c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + // TODO Add support for kbatch > 1 in grouped gemm + static constexpr ck_tile::index_t k_batch = 1; + gemm_descs.push_back( + {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } invoke_gemm(warmup, repeat, group_count, gemm_descs); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 539b7d9db8..8e06e9fa33 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -181,6 +181,23 @@ struct BlockwiseGemmXdlops_pipeline_base using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. 
+ */ __host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), Tuple4 b_origin = CalculateBThreadOriginDataIndex()) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp new file mode 100644 index 0000000000..24f6afc381 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp" + +namespace ck { + +template +constexpr auto BlockGemmMXPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_mx{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp new file mode 100644 index 0000000000..628dafb063 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -0,0 +1,617 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Tuple4 = typename Base::Tuple4; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static auto 
CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]); + } + + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx. + * + * The primary purpose of this constructor is to modify default initialization of the base class + * with the origin data index suitable for microscaling. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + */ + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : Base(a_origin, b_origin) + { + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + 
index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk * + xdlops_gemm.mfma_instr.group_size == + xdlops_gemm.GetRegSizePerXdlops(), + "Assume num_regs_per_blk == num_groups_per_blk * group_size"); + + // Prefetch a_scales + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + auto a_scale_thread_buf_group = + make_static_buffer( + a_scale_thread_desc_group.GetElementSpaceSize()); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_group, + make_tuple(I0, I0), + a_scale_thread_buf_group); + + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i)); + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_group[Number{}]; + }); + // go to the next group + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); + }); // g + + // restore row id and advance to the next scale + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * + xdlops_gemm.mfma_instr.num_groups_per_blk, + 1)); + }); // k0 + + // restore column id and 
advance to the next set of rows + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // m0 + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + // Prefetch b_scales + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, 0)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + 
make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + c_thread_buf_per_scale.Clear(); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + // MFMA accumulation + // m = 1:MPerXDL + // n = 1:NPerXDL + // k = 1:KPack + // c(m,n) += a(m,k)*b(k,n) + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + + // one scale per k0 + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}( + [&](auto g) { + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}( + [&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(m0, k0, g, r)); + + constexpr auto reg_offset = + g * xdlops_gemm.mfma_instr.group_size + r; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, reg_offset)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert( + b_scale_thread_buf[Number{}]) * + type_convert( + a_scale_thread_buf[Number{}]); + }); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + auto a_scale_thread_buf_group = + make_static_buffer( + a_scale_thread_desc_group.GetElementSpaceSize()); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_group, + make_tuple(I0, 
I0), + a_scale_thread_buf_group); + + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_group[Number{}]; + }); + // go to the next group + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); + }); // g + + // restore row id and advance to the next scale + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * + xdlops_gemm.mfma_instr.num_groups_per_blk, + 1)); + }); // k0 + + // restore column id and advance to the next set of rows + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // m0 + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, 0)); + }); + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, 
MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + c_thread_buf_per_scale.Clear(); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + + // one scale per k0 + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + + constexpr auto reg_offset = + g * xdlops_gemm.mfma_instr.group_size + r; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert( + b_scale_thread_buf[Number{}]) * + type_convert( + a_scale_thread_buf[Number{}]); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + 
Number{})); + + // Is used to copy data from a_scale_grid to a_scale_thread + static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved here + static constexpr auto b_scale_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp new file mode 100644 index 0000000000..e89185a35c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** @brief Abstract device-operation interface for a GEMM that takes A, B, one scale buffer for each of them, and a C output. It derives from BaseOperator and declares two pure-virtual factories; it holds no state of its own. NOTE(review): the template parameter list is not visible in this view, so confirm the meaning of each type parameter against the original header. */ template +struct DeviceGemmMX : public BaseOperator +{ + /** @brief Pure-virtual factory that packs type-erased buffers, problem sizes, strides and element-wise operators into an argument object. @param p_a, p_b untyped read-only buffers for A and B. @param p_a_scale, p_b_scale untyped read-only scale buffers. @param p_c untyped writable output buffer. @param M, N, K problem sizes (presumably the GEMM dimensions; confirm). @param StrideA, StrideAScale, StrideB, StrideBScale, StrideC strides of the corresponding buffers (units not visible here; confirm). @param KBatch presumably the split-K factor; confirm against the implementing class. @param a_element_op, b_element_op, c_element_op element-wise operators, passed by value. @return owning pointer to the created argument object. NOTE(review): the implementation in this same diff names the two scale strides StrideScaleA and StrideScaleB; the positions match, only the names differ. */ virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + /** @brief Pure-virtual factory returning an owning pointer to the invoker object that runs this operation. */ virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp new file mode 100644 index 0000000000..34df9a1d7b
--- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types + * + * This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for + * microscale-compliant data types. + * + * Assumptions: + * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification + * - Each scale applies to ScaleBlockSize elements in K direction + * - A scale matrix is row-major + * - B scale matrix is column-major + * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the + * exponent will be interpreted as conventional biased Float32 exponent (E8M0) + * + * Tunable parameters. + * The CK instance includes a series of tunable template parameters to control the parallel + * granularity of the workload to achieve load balancing on different hardware platforms. These + * parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc. + * - Block Size determines the number of threads in the thread block. 
+ * - M/N/K Per Block determines the size of tile that each thread block is responsible for + * calculating. + * - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) + * instructions operating on a per-wavefront basis. + * - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To + * achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B + * loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will + * result in compilation errors. + * + * Conditions for achieving computational load balancing on different hardware platforms can vary. + * + * Serialized version of the algorithm: + * \code + * // E = A * B + C + * // Loop over E[MPerBlock,NPerBlock] tiles + * for(int mb = 0; mb < M; mb += MPerBlock){ + * for(int nb = 0; nb < N; nb += NPerBlock){ + * // initialize E[MPerBlock,NPerBlock] tile + * for(int mt = mb; mt < mb + MPerBlock; mt++){ + * for(int nt = nb; nt < nb + NPerBlock; nt++){ + * E[mt,nt] = C[mt,nt]; + * } + * } + * + * // multiply-accumulate per tile + * for(int kb = 0; kb < K; kb += KPerBlock){ + * for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){ + * for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){ + * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ + * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ + * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ + * // MFMA accumulation for multirate instructions + * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){ + * for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){ + * // MFMA instruction + * for(int m = mw; m < mw + MPerXDL; m++){ + * for(int n = nw; n < nw + NPerXDL; n++){ + * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ + * E[m,n] += A[m,k] * B[k,n]; + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } 
+ * } + * \endcode + * + */ +template +struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == 
BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + 
TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const 
auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const 
auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + /* NOTE(review): when there is no main K-block loop, a kernel is selected only for pipeline version v1. For any other BlkGemmPipelineVer this branch launches nothing and the function returns ave_time as initialised before the dispatch. Confirm that this is intended. */ if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + /** @brief Type-erased entry point: casts p_arg with dynamic_cast and forwards to the typed Run overload with the same stream_config. @return whatever the typed Run returns. NOTE(review): the cast result is dereferenced without a null check, so passing an argument object of a different dynamic type dereferences a null pointer. */ float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + /** @brief Compile-time guard. Instantiating it triggers two static_asserts: one on a set of is_same_v tests (message: only microscaling formats are supported for ADataType and BDataType) and one requiring ScaleBlockSize == 32. @return always true when the asserts pass. The type arguments of the is_same_v tests are not visible in this view. */ static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + static_assert((is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v)&&(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v), + "Only microscaling formats are supported for ADataType and BDataType"); + + static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); + + return true; + } + + /** @brief Runtime check of a typed argument. Returns false in three cases: (1) ck::is_xdl_supported() is false; (2) is_bf16_atomic_supported() is false while the is_same_v type test holds and arg.KBatch > 1 (the tested types are not visible here; presumably the C data type against bf16, confirm); (3) arg.K is not a multiple of AK1 or not a multiple of BK1 and GemmSpec is none of MKPadding, NKPadding, MNKPadding, KPadding. Otherwise the result is GridwiseGemm::CheckValidity(arg). */ static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + /** @brief Type-erased overload: casts p_arg with dynamic_cast and forwards to the typed IsSupportedArgument. NOTE(review): the cast result is dereferenced without a null check, so an argument object of a different dynamic type dereferences a null pointer instead of yielding false. */ bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + /** @brief Builds a typed Argument by value, forwarding every parameter unchanged and in declaration order to the Argument brace-initialiser. Pointer parameters are typed here (ADataType, AScaleDataType, BDataType, BScaleDataType, CDataType), unlike the type-erased MakeArgumentPointer. */ static auto MakeArgument(const ADataType* p_a, + const AScaleDataType* p_a_scale, + const BDataType* p_b, + const BScaleDataType* p_b_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t
StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_a_scale, + p_b, + p_b_scale, + p_c, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + /** @brief Returns a value-initialised Invoker by value. */ static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + /** @brief Type-erased counterpart of MakeArgument. Each untyped pointer is converted with static_cast (target types are not visible in this view; presumably the matching typed pointers, confirm) and the remaining parameters are forwarded unchanged, in the same order, to std::make_unique. The caller is responsible for passing buffers of the element types this instance was built for, because static_cast from void pointers performs no check. @return owning pointer to the new argument object. */ std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideScaleA, + ck::index_t StrideB, + ck::index_t StrideScaleB, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + /** @brief Returns an owning pointer to a heap-allocated object constructed from a value-initialised Invoker. */ std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + /** @brief Builds a human-readable description of this instance through a std::stringstream, starting with the class name and the GemmSpec string. The two std::map lookup tables below translate the scheduler and pipeline-version enumerators to short names; they are local variables, so they are rebuilt on every call. */ std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMX_Xdl_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + <<
std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + 
splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMX_xdl_cshuffle_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? true + : false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K 
+ K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + /** @brief Rounds K up to a whole number of chunks of size K_Batch * KPerBlock (integer arithmetic) and returns that chunk count multiplied by KPerBlock / BK1Value. @param K extent along K. @param K_Batch multiplier applied to KPerBlock to form the chunk size; defaults to 1. */ __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + /** @brief Same rounding as above, but the chunk count is multiplied by KPerBlock itself. NOTE(review): this struct also declares a one-parameter CalculateKPadded(index_t K) just before these helpers; because K_Batch is defaulted here, a call with a single argument is ambiguous between the two overloads. Confirm that every caller passes K_Batch explicitly, or remove one overload. */ __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + /** @brief Same rounding scheme with KReadVec = lcm(AK1Number, BK1Number) used in place of KPerBlock: K is rounded up to whole chunks of K_Batch * KReadVec and the chunk count is multiplied by KReadVec. */ __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + /** @brief Returns math::integer_divide_ceil(M, MPerBlock), i.e. the number of MPerBlock-sized tiles along M. */ __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + /** @brief Returns math::integer_divide_ceil(N, NPerBlock), i.e. the number of NPerBlock-sized tiles along N. */ __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + /** @brief Re-indexes a compile-time three-dimensional tile descriptor. K0 and K1 are read as the lengths of input dimensions 0 and 2. Input dimensions 0 and 2 are merged into output dimension 3, and input dimension 1 is unmerged into output dimensions 0, 1 and 2. Only the type of the argument is used; the value is ignored. The template parameter list and the Number arguments are not visible in this view, so the three unmerge lengths should be confirmed against the original header. */ template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + /** @brief Builds the grid descriptor for A in (AK0, M, AK1) form. It starts from a naive (M, K) descriptor whose strides are (StrideA, 1) or (1, StrideA) depending on a compile-time is_same_v test (the tested types are not visible here; presumably ALayout against row-major and column-major, confirm). It then selects, by GemmSpec, which of M and K receive a right-pad transform before K is unmerged into (AK0, AK1Value). @param M, K raw extents. @param MPad, KPad padded extents; the pad amounts are MPad - M and KPad - K. @param StrideA stride of the non-unit dimension. @param AK0 outer length of the K unmerge. */ __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec ==
GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, 
AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return 
b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const 
ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == 
GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
+        {
+            // pad N, but not M
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            // not pad M or N
+            return c_grid_desc_mraw_nraw;
+        }
+#endif
+    }
+
+    // Host-side description of one GEMM problem: the caller-supplied sizes, strides and
+    // split-K batch count, plus the padded sizes and block counts derived from them.
+    struct Problem
+    {
+        // Stores the nine inputs and computes every derived field from the constructor
+        // parameters (not from other members), so member declaration order does not matter:
+        //   MPadded / NPadded  <- CalculateMPadded(M_) / CalculateNPadded(N_)
+        //   KRead / KPadded    <- CalculateKRead(K_, KBatch_) / CalculateKPadded(K_, KBatch_)
+        //   AK0 / BK0          <- CalculateAK0Padded(K_, KBatch_) / CalculateBK0Padded(K_, KBatch_)
+        //   MBlock / NBlock    <- CalculateMBlock(M_) / CalculateNBlock(N_)
+        __host__ Problem(index_t M_,
+                         index_t N_,
+                         index_t K_,
+                         index_t StrideA_,
+                         index_t StrideScaleA_,
+                         index_t StrideB_,
+                         index_t StrideScaleB_,
+                         index_t StrideC_,
+                         index_t KBatch_)
+            : M{M_},
+              N{N_},
+              K{K_},
+              StrideA{StrideA_},
+              StrideScaleA{StrideScaleA_},
+              StrideB{StrideB_},
+              StrideScaleB{StrideScaleB_},
+              StrideC{StrideC_},
+              KBatch{KBatch_},
+              MPadded{CalculateMPadded(M_)},
+              NPadded{CalculateNPadded(N_)},
+              KRead{CalculateKRead(K_, KBatch_)},
+              KPadded{CalculateKPadded(K_, KBatch_)},
+              AK0{CalculateAK0Padded(K_, KBatch_)},
+              BK0{CalculateBK0Padded(K_, KBatch_)},
+              MBlock{CalculateMBlock(M_)},
+              NBlock{CalculateNBlock(N_)}
+        {
+        }
+
+        // Writes the fields to std::cout on one line, ending with std::endl (which also
+        // flushes the stream). Every field except KBatch is printed.
+        __host__ void Print() const
+        {
+            std::cout << "problem {"
+                      << "M:" << M << ", "
+                      << "N:" << N << ", "
+                      << "K:" << K << ", "
+                      << "SA:" << StrideA << ", "
+                      << "SScaleA:" << StrideScaleA << ", "
+                      << "SB:" << StrideB << ", "
+                      << "SScaleB:" << StrideScaleB << ", "
+                      << "SC:" << StrideC << ", "
+                      << "MP:" << MPadded << ", "
+                      << "NP:" << NPadded << ", "
+                      << "KRead:" << KRead << ", "
+                      << "KP:" << KPadded << ", "
+                      << "AK0:" << AK0 << ", "
+                      << "BK0:" << BK0 << ", "
+                      << "MBlock: " << MBlock << ", "
+                      << "NBlock: " << NBlock << "}" << std::endl;
+        }
+
+        // Inputs, stored as passed to the constructor.
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t StrideA;
+        index_t StrideScaleA;
+        index_t StrideB;
+        index_t StrideScaleB;
+        index_t StrideC;
+        index_t KBatch;
+        // Derived values, computed once in the constructor.
+        index_t MPadded;
+        index_t NPadded;
+        index_t KRead;
+        index_t KPadded;
+        index_t AK0;
+        index_t BK0;
+        index_t MBlock;
+        index_t NBlock;
+    };
+
+    // Argument
+    struct
Argument : public tensor_operation::device::BaseArgument, public Problem
+    {
+        // Bundles the device pointers (A, A-scale, B, B-scale, C), the three elementwise
+        // operators and the is_reduce flag with the Problem sizes; the size, stride and
+        // k_batch_ arguments are forwarded unchanged to the Problem base.
+        __host__ Argument(const ADataType* p_a_grid_,
+                          const AScaleDataType* p_a_scale_grid_,
+                          const BDataType* p_b_grid_,
+                          const BScaleDataType* p_b_scale_grid_,
+                          CDataType* p_c_grid_,
+                          index_t M_,
+                          index_t N_,
+                          index_t K_,
+                          index_t StrideA_,
+                          index_t StrideScaleA_,
+                          index_t StrideB_,
+                          index_t StrideScaleB_,
+                          index_t StrideC_,
+                          index_t k_batch_,
+                          AElementwiseOperation a_element_op_,
+                          BElementwiseOperation b_element_op_,
+                          CElementwiseOperation c_element_op_,
+                          bool is_reduce_ = false)
+            : Problem{M_,
+                      N_,
+                      K_,
+                      StrideA_,
+                      StrideScaleA_,
+                      StrideB_,
+                      StrideScaleB_,
+                      StrideC_,
+                      k_batch_},
+              p_a_grid{p_a_grid_},
+              p_a_scale_grid{p_a_scale_grid_},
+              p_b_grid{p_b_grid_},
+              p_b_scale_grid{p_b_scale_grid_},
+              p_c_grid{p_c_grid_},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              c_element_op{c_element_op_},
+              is_reduce(is_reduce_)
+        {
+        }
+
+        // True when K is split (KBatch > 1) and is_reduce is set.
+        __host__ __device__ inline bool IsReduceAdd() const
+        {
+            return (Problem::KBatch > 1) && is_reduce;
+        }
+
+        // True when K is split (KBatch > 1) and is_reduce is not set.
+        __host__ __device__ inline bool IsAtomicAdd() const
+        {
+            return (Problem::KBatch > 1) && (!is_reduce);
+        }
+
+        const ADataType* p_a_grid;
+        const AScaleDataType* p_a_scale_grid;
+        const BDataType* p_b_grid;
+        const BScaleDataType* p_b_scale_grid;
+        CDataType* p_c_grid;
+
+        const AElementwiseOperation a_element_op;
+        const BElementwiseOperation b_element_op;
+        const CElementwiseOperation c_element_op;
+        bool is_reduce;
+    };
+
+    // Element offsets at which split number k_id starts in A, B, their scale tensors and the
+    // C/reduction buffer.
+    // NOTE(review): the `is_same_v` conditions below have lost their template arguments in this
+    // copy of the patch, so which branch belongs to which layout cannot be read from this text.
+    struct SplitKBatchOffset
+    {
+
+        // Computes the five offsets for split k_id.
+        // Side effect: overwrites karg.K with the K extent handled by this split (see the end of
+        // the constructor), which is why karg is taken by non-const reference.
+        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
+        {
+            // A offset: either k_id * KRead elements scaled down by APackedSize, or k_id * KRead
+            // steps of StrideA.
+            if constexpr(is_same_v)
+            {
+                a_k_split_offset = k_id * karg.KRead / APackedSize;
+            }
+            else if constexpr(is_same_v)
+            {
+                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
+            }
+
+            // B offset: strided form, or the packed form scaled down by BPackedSize.
+            if constexpr(is_same_v)
+            {
+                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
+            }
+            else if constexpr(is_same_v)
+            {
+                if constexpr(!PermuteB)
+                {
+                    b_k_split_offset = k_id * karg.KRead / BPackedSize;
+                }
+                else
+                {
+                    // With PermuteB each split advances by KRead * N elements.
+                    const int k0_offset = karg.KRead * karg.N;
+                    b_k_split_offset    = k_id * k0_offset / BPackedSize;
+                }
+            }
+
+            // Calculate A scale offset
+            // NOTE(review): here the division by ScaleBlockSize is applied to k_id * KRead, while
+            // the strided B-scale form below divides KRead first; the two differ whenever KRead
+            // is not a multiple of ScaleBlockSize -- confirm which is intended.
+            if constexpr(is_same_v)
+            {
+                a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
+            }
+            else if constexpr(is_same_v)
+            {
+                a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA;
+            }
+
+            // Calculate B scale offset
+            if constexpr(is_same_v)
+            {
+                b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB;
+            }
+            else if constexpr(is_same_v)
+            {
+                b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;
+            }
+
+            // Every split except the last handles KRead elements of K; the last split takes
+            // whatever remains of the original K.
+            if(k_id < (karg.KBatch - 1))
+            {
+                karg.K = karg.KRead;
+            }
+            else
+            {
+                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
+            }
+
+            // In reduce mode each split gets its own M * N slot; otherwise the offset is 0.
+            if(karg.IsReduceAdd())
+            {
+                c_reduce_offset = k_id * karg.M * karg.N;
+            }
+            else
+            {
+                c_reduce_offset = 0;
+            }
+        }
+
+        index_t a_k_split_offset;
+        index_t b_k_split_offset;
+        index_t a_scale_k_split_offset; // New member for scale matrix offset
+        index_t b_scale_k_split_offset; // New member for scale matrix offset
+        index_t c_reduce_offset;
+    };
+
+    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+    {
+        // A matrix in LDS memory, dst of blockwise copy
+        if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(AK0Number, Number{}, AK1Number),
+                make_tuple(AK1Number, Number{}, I1));
+        }
+        // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
+        // in some cases.
+        else if constexpr(is_same::value)
+        {
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
+            constexpr auto MLdsLayer = LdsSize < 1 ?
1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + 
make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 
1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const 
Argument& karg)
+    {
+        // Host-side validation of a kernel argument against this instantiation's tuning
+        // parameters. Returns false as soon as one requirement is not met; when the CK_LOGGING
+        // environment switch is enabled most rejections also print a reason to std::cout.
+        // NOTE(review): several `is_same::value` layout/type tests below have lost their
+        // template arguments in this copy of the patch; they are reproduced as found.
+        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
+                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
+                      "Invalid tuning param!");
+
+        static_assert(KPerBlock % ScaleBlockSize == 0,
+                      "KPerBlock should be multiple of ScaleBlockSize");
+
+        static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat,
+                      "Single call to xdlops_gemm::Run should process exactly ScaleBlockSize "
+                      "elements in k dimension");
+
+        // Without any M-padding specialization, M must be a multiple of MPerBlock.
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same::value))
+        {
+            if(!(karg.M % MPerBlock == 0))
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
+                return false;
+            }
+        }
+
+        // Without any N-padding specialization, N must be a multiple of NPerBlock.
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same::value))
+        {
+            if(!(karg.N % NPerBlock == 0))
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
+                return false;
+            }
+        }
+
+        // Without any K-padding specialization, K must be a multiple of KBatch * KPerBlock.
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            auto K_t = karg.KBatch * KPerBlock;
+            if(!(karg.K % K_t == 0))
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
+                              << karg.K << " " << __FILE__ << ":" << __LINE__
+                              << ", in function: " << __func__ << std::endl;
+                }
+                return false;
+            }
+        }
+        else
+        {
+            // With K padding: reject when the first KBatch - 1 splits, each rounded up to the
+            // read vector width, would already cover all of K. This path returns without logging.
+            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
+            auto K_t                = karg.KBatch * KReadVec;
+            auto KReadPadSplited    = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
+            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
+            {
+                return false;
+            }
+        }
+
+        // A source vector width must divide K or M, depending on the (stripped) layout test.
+        if constexpr(is_same::value)
+        {
+            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K (" << karg.K
+                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
+                return false;
+            }
+        }
+        else
+        {
+            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M (" << karg.M
+                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
+                return false;
+            }
+        }
+
+        // B source vector width must divide N or K, depending on the (stripped) layout test.
+        if constexpr(is_same::value)
+        {
+            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N (" << karg.N
+                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
+                return false;
+            }
+        }
+        else
+        {
+            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K (" << karg.K
+                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
+                return false;
+            }
+        }
+
+        // C store vector width must divide N or M, depending on the (stripped) layout test.
+        if constexpr(is_same::value)
+        {
+            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N (" << karg.N
+                              << ") value is not a multiple of "
+                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
+                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
+                return false;
+            }
+        }
+        else
+        {
+            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M (" << karg.M
+                              << ") value is not a multiple of "
+                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
+                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
+                return false;
+            }
+        }
+
+        // When the (stripped) type test matches none of half_t / float / bhalf_t / int32_t,
+        // split-K is accepted only in reduce mode.
+        if constexpr(!(is_same, half_t>::value ||
+                       is_same, float>::value ||
+                       is_same, bhalf_t>::value ||
+                       is_same, int32_t>::value))
+        {
+            if(!karg.IsReduceAdd())
+            {
+                // Log only when the argument is actually rejected. The message used to be printed
+                // before the KBatch test, so it also appeared for KBatch == 1 configurations that
+                // were then accepted.
+                if(karg.KBatch > 1)
+                {
+                    if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                    {
+                        std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet"
+                                  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                                  << std::endl;
+                    }
+                    return false;
+                }
+            }
+        }
+
+        // check gridwise gemm pipeline
+        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
+
+        // Every pipeline version other than v1 needs more K iterations than PrefetchStages.
+        if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
+        {
+            if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
+            {
+                return false;
+            }
+        }
+
+        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        return true;
+    }
+
+    // Forwards K / KPerBlock (integer division) to BlockwiseGemmPipe::BlockHasHotloop.
+    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    {
+        const index_t num_loop = K / KPerBlock;
+
+        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
+    }
+
+    // Forwards K / KPerBlock (integer division) to BlockwiseGemmPipe::BlockLoopTailNum.
+    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
+    {
+        const index_t num_loop = K / KPerBlock;
+
+        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
+    }
+
+    // Unmerges the 2-D C descriptor into four dimensions: dimension 0 becomes output dimensions
+    // 0 and 1 (leading length MBlock), dimension 1 becomes output dimensions 2 and 3 (leading
+    // length NBlock).
+    template
+    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
+    {
+        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+            c_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})),
+                       make_unmerge_transform(make_tuple(NBlock, Number{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+        return c_grid_desc_mblock_mperblock_nblock_nperblock;
+    }
+
+    // return block_id to C matrix tile idx (m0, n0) mapping
+    // if arch = gfx942
+    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // 
HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 
n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + // NXdlPerWave == NRepeat + // MXdlPerWave == MRepeat + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + // Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2 + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 
191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + auto a_thread_offset_m = + MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) + + mfma.selected_mfma.group_size * + ((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl); + auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl; + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), // SrcDesc + decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc + Sequence, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 0, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, + a_thread_offset_k / ScaleBlockSize)); + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector + 1, + false>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, + b_thread_offset_k / ScaleBlockSize)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + 
a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 
(MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + 
m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to 
LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A Scale grid + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(problem.M, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleA, 1)); + + // B Scale grid transposed + const auto b_scale_grid_desc_bn_ak = 
make_naive_tensor_descriptor( + make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleB, 1)); + + Run(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + ignore = p_a_scale_grid; + ignore = a_scale_grid_desc_am_ak; + + // TODO: Implement 2 LDS version + static_assert(false, "Not implemented"); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + 
if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + 
decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + 
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // B scale + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + const index_t ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockSize - 1) / ScaleBlockSize; + static constexpr auto KBlockScaleSliceSizeK = + (KPerBlock + ScaleBlockSize - 1) / ScaleBlockSize; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = + (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, + b_thread_offset_k / ScaleBlockSize)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); 
+ + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + 
make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: 
blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleB, 1)); + + Run_2Lds(p_a_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + 
b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 92f9fc0b6e..73522b1dde 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1655,14 +1655,18 @@ struct GridwiseMoeGemm CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; + StaticallyIndexedArray scatter_offsets; + StaticallyIndexedArray scatter_weights; //= for topk auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; + index_t token_offset = fused_token & 0xffffff; + float weight = token_offset < problem.NumTokens + ? 
p_sorted_weights_0[token_offset * problem.StrideDs[0]] + : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2150,7 +2154,9 @@ struct GridwiseMoeGemm CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; + StaticallyIndexedArray + scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_weights; //= for topk auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2158,6 +2164,9 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; + float weight = token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] + : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index a7efea277f..0310fe37a0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3 const ElementwiseOperation element_op_; }; // namespace ThreadwiseTensorSliceTransfer_v1r3 -// Assume: -// 1. src: -// 1. SrcDesc is not known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_slice_origin_idx is not known at compile-time -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. dst_slice_origin_idx is known at compile-time +/** + * @brief Helper structure that facilitates transfer of source (grid) data to destination threads. 
+ * + * @details The following assumptions are made: + * - For Source (Grid) Data: + * 1. The source tensor descriptor SrcDesc is not known at compile-time. + * 2. The source buffer is a dynamic buffer. + * 3. The source slice origin index src_slice_origin_idx is not known at compile-time. + * - For Destination (Thread) Data: + * 1. The destination tensor descriptor DstDesc is known at compile-time. + * 2. The destination buffer dst_buf is a static buffer. + * 3. The destination slice origin index dst_slice_origin_idx is known at compile-time. + * + * @tparam SrcData The data type of the source tensor. + * @tparam DstData The data type of the destination tensor. + * @tparam SrcDesc The descriptor type of the source tensor. + * @tparam DstDesc The descriptor type of the destination tensor. + * @tparam SliceLengths The lengths of the slice to be transferred. + * @tparam DimAccessOrder The order of dimension access for the space-filling curve. + * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source + * tensor. + * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. + * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source + * tensor. + * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run + * or rolled back one step in MoveSrcSliceWindow + * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for + * floating-point types). 
+ * + */ template {}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); - // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], - // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); dst_bufs(i).template Update( diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 9f6c0b5648..a638ca8608 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -793,7 +793,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 32; // from the instruction static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -817,7 +817,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 16; // from the instruction static constexpr index_t n_per_blk = 16; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -841,7 +841,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? 
static constexpr index_t m_per_blk = 32; // from the instruction static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -870,7 +870,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 16; // from the instruction static constexpr index_t n_per_blk = 16; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 8df0d885b9..0ddfd0a7c8 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -106,9 +106,10 @@ struct TransformConvBwdDataToGemm_v1 } else { - // Not possible to support even after split N. - // Too large tensor. - return N; + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. 
+ return 1; } } else diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 3db94deccb..c291f3994c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -83,9 +83,10 @@ struct TransformConvFwdToGemm } else { - // Not possible to support even after split N. - // Too large tensor. - return N; + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; } } else diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 429ba44b89..5c80c42d6c 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -243,7 +243,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) #if CK_FP8_CVT_FAST_PATH template -static __device__ float cast_to_f32_from_f8(fp8_storage_t v) +static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v) { union { diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index de59f200f0..0ed60df2c3 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -5,7 +5,7 @@ #define CK_AMD_INLINE_ASM_HPP #include "c_style_pointer_cast.hpp" -#include "data_type.hpp" +#include "dtype_vector.hpp" // TODO: deprecate all amd_assembly_outer_product_xxx diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 396e375d8c..0d4611becc 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck/utility/dtype_fp64.hpp" namespace ck { // Define the common macro for MI300 models diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index b25ab5ab5f..9732739994 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -346,53 +346,6 @@ inline constexpr bool is_native_type() is_same::value || is_same::value; } -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; - -template -using vector_type_maker_t = typename vector_type_maker::type; - -template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - // scalar_type template struct scalar_type; @@ -416,13 +369,6 @@ struct scalar_type static constexpr index_t vector_size = N; }; -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - // template <> struct scalar_type @@ -517,6 +463,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = e8m0_bexp_t::type; + static constexpr index_t vector_size = 
1; +}; + template <> struct scalar_type { @@ -524,2864 +477,10 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -template -struct vector_type()>> -{ - using d1_t = T; - using type = d1_t; - - union - { - T d1_; - StaticallyIndexedArray d1x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } -}; - -__device__ int static err = 0; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - 
typedef T d3_t __attribute__((ext_vector_type(3))); - - using type = d3_t; - - union - { - d3_t d3_; - StaticallyIndexedArray d1x3_; - StaticallyIndexedArray d2x1_; - StaticallyIndexedArray d3x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - - using type = d4_t; - - union - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if 
constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d5_t __attribute__((ext_vector_type(5))); - - using type = d5_t; - - union - { - d5_t d5_; - StaticallyIndexedArray d1x5_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d5x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - 
typedef T d7_t __attribute__((ext_vector_type(7))); - - using type = d7_t; - - union - { - d7_t d7_; - StaticallyIndexedArray d1x7_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d7x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const 
auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d13_t __attribute__((ext_vector_type(13))); - - using type = d13_t; - - union - { - d13_t d13_; - StaticallyIndexedArray d1x13_; - StaticallyIndexedArray d4x3_; - StaticallyIndexedArray d8x1_; - StaticallyIndexedArray d13x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return 
data_.d13x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went 
wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - - using type = d32_t; - - union - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_ = {d32_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - // __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - // __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - 
return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - - using type = d64_t; - - union - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } 
- else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - - using type = d128_t; - - union - { - d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - 
__host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T 
d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - - using type = d256_t; - - union - { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if 
constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } -}; - -template -struct non_native_vector_base; - -template -struct nnvb_data_t_selector -{ - using type = unsigned _BitInt(8 * sizeof(T)); -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x16_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x32_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x16_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x32_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = pk_i4_t::type; -}; - -template -struct non_native_vector_base< - T, - N, - ck::enable_if_t> -{ - using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - using data_v = data_t __attribute__((ext_vector_type(N))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - __host__ __device__ constexpr non_native_vector_base(data_t a) : 
data_{data_v(a)} {} - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return data_.dTxN; // XXX this should cause an error - } - } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } -}; - -// implementation for f6x16 and f6x32 -template -struct non_native_vector_base> -{ - using data_t = - typename nnvb_data_t_selector::type; // select data_t based on declared base type - using element_t = typename T::element_type; // select element_t based on declared element type - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - static constexpr size_t size_factor = - sizeof(data_t) / 
sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 - using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - __host__ __device__ constexpr non_native_vector_base(data_t a) - : data_{data_v(a.At(Number<0>{}))} - { - } - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return data_.dTxN; // XXX this should cause an error - } - } -}; - -template -struct scalar_type>; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -// non-native vector_type implementation -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using type = d1_nnv_t; - - union alignas(next_pow2(1 * sizeof(T))) - { - d1_t d1_; - StaticallyIndexedArray d1x1_; - d1_nnv_t d1_nnv_; - } data_; - - __host__ __device__ 
constexpr vector_type() : data_{d1_t{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - - using type = d2_t; - - union alignas(next_pow2(2 * sizeof(T))) - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; 
- using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - - using type = d4_t; - - union alignas(next_pow2(4 * sizeof(T))) - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - - using type = d8_t; - - union alignas(next_pow2(8 * sizeof(T))) - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - 
static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - - using type = d16_t; - - union alignas(next_pow2(16 * sizeof(T))) - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - 
return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - - using type = d32_t; - - union alignas(next_pow2(32 * sizeof(T))) - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if 
constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - using d64_t = non_native_vector_base; - - using type = d64_t; - - union alignas(next_pow2(64 * sizeof(T))) - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value 
|| - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - -using int64_t = long; - -// fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; - -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; - -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename 
vector_type::type; -using half64_t = typename vector_type::type; - -// bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; - -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; - -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; - -// f8 -using f8x2_fnuz_t = typename vector_type::type; -using f8x4_fnuz_t = typename vector_type::type; -using f8x8_fnuz_t = typename vector_type::type; -using f8x16_fnuz_t = typename vector_type::type; -using f8x32_fnuz_t = typename vector_type::type; -using f8x64_fnuz_t = typename vector_type::type; - -// bf8 -using bf8x2_fnuz_t = typename vector_type::type; -using bf8x4_fnuz_t = typename vector_type::type; -using bf8x8_fnuz_t = typename vector_type::type; -using bf8x16_fnuz_t = typename vector_type::type; -using bf8x32_fnuz_t = typename vector_type::type; -using bf8x64_fnuz_t = typename vector_type::type; - -// f8 -using f8x2_ocp_t = typename vector_type::type; -using f8x4_ocp_t = typename vector_type::type; -using f8x8_ocp_t = typename vector_type::type; -using f8x16_ocp_t = typename vector_type::type; -using f8x32_ocp_t = typename vector_type::type; -using f8x64_ocp_t = typename vector_type::type; - -// bf8 -using bf8x2_ocp_t = typename vector_type::type; -using bf8x4_ocp_t = typename vector_type::type; -using bf8x8_ocp_t = typename 
vector_type::type; -using bf8x16_ocp_t = typename vector_type::type; -using bf8x32_ocp_t = typename vector_type::type; -using bf8x64_ocp_t = typename vector_type::type; - -#if CK_FP8_TYPE_OCP -// f8 -using f8x2_t = f8x2_ocp_t; -using f8x4_t = f8x4_ocp_t; -using f8x8_t = f8x8_ocp_t; -using f8x16_t = f8x16_ocp_t; -using f8x32_t = f8x32_ocp_t; -using f8x64_t = f8x64_ocp_t; - -// bf8 -using bf8x2_t = bf8x2_ocp_t; -using bf8x4_t = bf8x4_ocp_t; -using bf8x8_t = bf8x8_ocp_t; -using bf8x16_t = bf8x16_ocp_t; -using bf8x32_t = bf8x32_ocp_t; -using bf8x64_t = bf8x64_ocp_t; -#elif CK_FP8_TYPE_FNUZ -// f8 -using f8x2_t = f8x2_fnuz_t; -using f8x4_t = f8x4_fnuz_t; -using f8x8_t = f8x8_fnuz_t; -using f8x16_t = f8x16_fnuz_t; -using f8x32_t = f8x32_fnuz_t; -using f8x64_t = f8x64_fnuz_t; - -// bf8 -using bf8x2_t = bf8x2_fnuz_t; -using bf8x4_t = bf8x4_fnuz_t; -using bf8x8_t = bf8x8_fnuz_t; -using bf8x16_t = bf8x16_fnuz_t; -using bf8x32_t = bf8x32_fnuz_t; -using bf8x64_t = bf8x64_fnuz_t; -#endif - -// u8 -using uint8x2_t = typename vector_type::type; -using uint8x4_t = typename vector_type::type; -using uint8x8_t = typename vector_type::type; -using uint8x16_t = typename vector_type::type; -using uint8x32_t = typename vector_type::type; -using uint8x64_t = typename vector_type::type; - -// f4 -using f4x2_t = typename vector_type::type; -using f4x4_t = typename vector_type::type; -using f4x8_t = typename vector_type::type; -using f4x16_t = typename vector_type::type; -using f4x32_t = typename vector_type::type; -using f4x64_t = typename vector_type::type; - -// f6 -using f6x16_t = typename vector_type::type; -using f6x32_t = typename vector_type::type; - -// bf6 -using bf6x16_t = typename vector_type::type; -using bf6x32_t = typename vector_type::type; - -// pack int4 -using pk_i4x2_t = typename vector_type::type; -using pk_i4x4_t = typename vector_type::type; -using pk_i4x8_t = typename vector_type::type; - -#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) -template -struct 
NumericLimits; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } - - __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } -}; -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } - - __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } - - __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } - - __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } - - __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } - - __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } - - __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } - - __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t 
Min() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } - - __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned int binary_min = 0x00800000; - static constexpr unsigned int binary_max = 0x7F7FFFFF; - static constexpr unsigned int binary_lowest = 0xFF7FFFFF; - static constexpr unsigned int binary_qnan = 0xFFC00001; - static constexpr unsigned int binary_inf = 0x7F8000000; - - __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } - - __host__ __device__ static constexpr float Infinity() { return bit_cast(binary_inf); } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned short binary_min = 0x0400; - static constexpr unsigned short binary_max = 0x7BFF; - static constexpr unsigned short binary_lowest = 0xFBFF; - static constexpr unsigned short binary_qnan = 0x7FFF; - - __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } -}; - -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } - - __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } - - __host__ __device__ static constexpr int4_t Lowest() { return 
int4_t(-8); } -}; -#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 8 - static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 7 - // static constexpr uint8_t binary_min = 0x08; // 0b00001000 - // static constexpr uint8_t binary_max = 0x77; // 0b01110111 - // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - - __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 16 - static constexpr uint8_t binary_min = 0x04; // 0b00000100 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 15 - // static constexpr uint8_t binary_min = 0x04; // 0b00000100 - // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 - // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - - __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return 
bf8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 - static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 - static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 - static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 - - __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr f8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr f8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 - static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 - static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 - static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 - - __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr bf8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr bf8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 - static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 - static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 - static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 - static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 - - static constexpr float data_max_normal_number = 6; - static constexpr float data_min_subnormal_number = 0.5; - - __host__ 
__device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } - __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } - __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } - __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } - __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 - - static constexpr float data_max_normal_number = 7.5; - static constexpr float data_min_subnormal_number = 0.125; - - __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Lowest() - { - return f6_t(binary_lowest_normal & 0b111111); - } - __host__ __device__ static constexpr f6_t MinSubnorm() - { - return f6_t(binary_min_subnorm & 0b111111); - } - __host__ __device__ static constexpr f6_t MaxSubnorm() - { - return f6_t(binary_max_subnorm & 0b111111); - } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t 
binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 - - static constexpr float data_max_normal_number = 28; - static constexpr float data_min_subnormal_number = 0.0625; - - __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } - __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } - __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } - __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } - __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 - static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 - static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 - static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 - static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 - static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 - static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 - static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 - - __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } - __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } - __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } - __host__ __device__ static constexpr 
e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_135() - { - return e8m0_bexp_t(binary_135); - } - __host__ __device__ static constexpr e8m0_bexp_t Binary_142() - { - return e8m0_bexp_t(binary_142); - } -}; +#if defined(_WIN32) +using int64_t = long long; #else -template -struct NumericLimits -{ - __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } - __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } - __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } - __host__ __device__ static constexpr T QuietNaN() - { - return std::numeric_limits::quiet_NaN(); - } - __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned short binary_min = 0x0400; - static constexpr unsigned short binary_max = 0x7BFF; - static constexpr unsigned short binary_lowest = 0xFBFF; - static constexpr unsigned short binary_qnan = 0x7FFF; - - __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } -}; - -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } - - __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } - - __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } -}; -#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - -template <> -struct NumericLimits -{ - // negative zero nan mode with 
exp bias = 8 - static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 7 - // static constexpr uint8_t binary_min = 0x08; // 0b00001000 - // static constexpr uint8_t binary_max = 0x77; // 0b01110111 - // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - - __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 16 - static constexpr uint8_t binary_min = 0x04; // 0b00000100 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 15 - // static constexpr uint8_t binary_min = 0x04; // 0b00000100 - // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 - // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - - __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } -}; - -template <> -struct 
NumericLimits -{ - static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 - static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 - static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 - static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 - - __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr f8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr f8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 - static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 - static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 - static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 - - __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr bf8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr bf8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 - static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 - static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 - static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 - static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 - - static constexpr float data_max_normal_number = 6; - static constexpr float data_min_subnormal_number = 0.5; - - __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } - __host__ __device__ static constexpr f4_t Max() { return 
f4_t(binary_max_normal); } - __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } - __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } - __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 - - static constexpr float data_max_normal_number = 7.5; - static constexpr float data_min_subnormal_number = 0.125; - - __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Lowest() - { - return f6_t(binary_lowest_normal & 0b111111); - } - __host__ __device__ static constexpr f6_t MinSubnorm() - { - return f6_t(binary_min_subnorm & 0b111111); - } - __host__ __device__ static constexpr f6_t MaxSubnorm() - { - return f6_t(binary_max_subnorm & 0b111111); - } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t 
binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 - - static constexpr float data_max_normal_number = 28; - static constexpr float data_min_subnormal_number = 0.0625; - - __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } - __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } - __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } - __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } - __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 - static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 - static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 - static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 - static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 - static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 - static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 - static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 - - __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } - __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } - __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return 
e8m0_bexp_t(binary_3); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_135() - { - return e8m0_bexp_t(binary_135); - } - __host__ __device__ static constexpr e8m0_bexp_t Binary_142() - { - return e8m0_bexp_t(binary_142); - } -}; +using int64_t = long; #endif -template -struct NumericUtils -{ -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 23; - static constexpr int bias = 127; - static constexpr uint32_t nan_mask = 0x7F800000; - static constexpr uint32_t head_mask = 0xFF800000; - static constexpr uint32_t mant_mask = 0x7FFFFF; - static constexpr uint32_t exp_mask = 0xFF; - static constexpr uint32_t Inf = 0x7F800000; - static constexpr uint32_t NegInf = 0xFF800000; - static constexpr uint32_t NaN = 0x7F800001; - static constexpr uint32_t Neg0 = 0x80000000; - static constexpr bool has_inf = true; - using bitwise_type = uint32_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 5; - static constexpr int mant = 10; - static constexpr int bias = 15; - static constexpr uint16_t nan_mask = 0x7C00; - static constexpr uint16_t head_mask = 0xFC00; - static constexpr uint16_t mant_mask = 0x3FF; - static constexpr uint16_t exp_mask = 0x1F; - static constexpr uint32_t Inf = 0x7C00; - static constexpr uint32_t NegInf = 0xFC00; - static constexpr uint32_t NaN = 0x7C01; - static constexpr uint32_t Neg0 = 0x8000; - static constexpr bool has_inf = true; - using bitwise_type = uint16_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 4; - static constexpr int mant = 3; - static constexpr int bias = 8; // negative zero nan mode - // static constexpr int bias = 7; // ieee mode - static constexpr bool has_inf = false; -}; - -template <> -struct NumericUtils 
-{ - static constexpr int exp = 5; - static constexpr int mant = 2; - static constexpr int bias = 16; // negative zero nan mode - // static constexpr int bias = 15; // ieee mode - static constexpr bool has_inf = false; -}; -template <> -struct NumericUtils -{ - static constexpr int exp = 4; - static constexpr int mant = 3; - static constexpr int bias = 7; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 5; - static constexpr int mant = 2; - static constexpr int bias = 15; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 2; - static constexpr int mant = 1; - static constexpr int bias = 1; - static constexpr uint32_t sr_shift = 10; - - static constexpr int unbiased_exp_min = 0; - static constexpr int unbiased_exp_max = 2; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 3; - - static constexpr uint8_t positive_zero_mask = 0b0000; - static constexpr uint8_t negative_zero_mask = 0b1000; - - static constexpr uint8_t one_mask = 0b0010; - static constexpr uint8_t set_sign_mask = 0b0111; - - static constexpr uint8_t data_max_positive_normal_mask = 0b0111; - static constexpr uint8_t data_max_negative_normal_mask = 0b1111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; - - static constexpr bool has_inf = false; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 2; - static constexpr int mant = 3; - static constexpr int bias = 1; - static constexpr uint32_t sr_shift = 12; - - static constexpr int unbiased_exp_min = 0; - static constexpr int unbiased_exp_max = 2; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 3; - - static constexpr uint8_t positive_zero_mask = 0b000000; - static constexpr uint8_t negative_zero_mask = 0b100000; - - static constexpr uint8_t set_sign_mask = 0b011111; - - static constexpr uint8_t 
data_max_positive_normal_mask = 0b011111; - static constexpr uint8_t data_max_negative_normal_mask = 0b111111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; - - static constexpr bool has_inf = false; - static constexpr bool has_nan = false; - static constexpr bool has_zero = true; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 3; - static constexpr int mant = 2; - static constexpr int bias = 3; - static constexpr uint32_t sr_shift = 11; - - static constexpr int unbiased_exp_min = -2; - static constexpr int unbiased_exp_max = 4; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 7; - - static constexpr uint8_t positive_zero_mask = 0b000000; - static constexpr uint8_t negative_zero_mask = 0b100000; - - static constexpr uint8_t set_sign_mask = 0b011111; - - static constexpr uint8_t data_max_positive_normal_mask = 0b011111; - static constexpr uint8_t data_max_negative_normal_mask = 0b111111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; - - static constexpr bool has_inf = false; - static constexpr bool has_nan = false; - static constexpr bool has_zero = true; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 0; - static constexpr int bias = 127; - - static constexpr int unbiased_exp_min = -127; - static constexpr int unbiased_exp_max = 127; - static constexpr int biased_exp_min = 0; - static constexpr int biased_exp_max = 254; - - using bitwise_type = uint8_t; -}; } // namespace ck diff --git a/include/ck/utility/dtype_fp64.hpp b/include/ck/utility/dtype_fp64.hpp new file mode 100644 index 0000000000..3c63d083ad --- /dev/null +++ b/include/ck/utility/dtype_fp64.hpp @@ -0,0 +1,7 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +namespace ck { +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; +} // namespace ck diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp new file mode 100644 index 0000000000..8f70962fa6 --- /dev/null +++ b/include/ck/utility/dtype_vector.hpp @@ -0,0 +1,2138 @@ +// SPDX-License-Identifier: MIT +// // // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. 
The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +template +struct vector_type()>> +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } +}; + +__device__ int static err = 0; +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst 
types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + + using type = d3_t; + + union + { + d3_t d3_; + StaticallyIndexedArray d1x3_; + StaticallyIndexedArray d2x1_; + StaticallyIndexedArray d3x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using 
type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d5_t __attribute__((ext_vector_type(5))); + + using type = d5_t; + + union + { + d5_t d5_; + StaticallyIndexedArray d1x5_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d5x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + 
else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d7_t __attribute__((ext_vector_type(7))); + + using type = d7_t; + + union + { + d7_t d7_; + StaticallyIndexedArray d1x7_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d7x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + else + { 
+ return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d13_t __attribute__((ext_vector_type(13))); + + using type = d13_t; + + union + { + d13_t d13_; + StaticallyIndexedArray d1x13_; + StaticallyIndexedArray d4x3_; + StaticallyIndexedArray d8x1_; + StaticallyIndexedArray d13x1_; 
+ } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, 
+ "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_ = {d32_t{0}}; + + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } + + // __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + // __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& 
AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; 
+ + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t 
__attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if 
constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + 
return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } +}; + +template +struct non_native_vector_base; + +template +struct nnvb_data_t_selector +{ + using type = unsigned _BitInt(8 * sizeof(T)); +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = e8m0_bexp_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = 
bf6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> +{ + using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || 
is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} + { + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } +}; + +template +struct scalar_type> +{ + using type = typename 
non_native_vector_base::data_t; + static constexpr index_t vector_size = N; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using type = d1_nnv_t; + + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + d1_nnv_t d1_nnv_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + + using type = d2_t; + + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + 
static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + + using type = d4_t; + + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union 
alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : 
data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ 
constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + 
StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +using int64_t = long; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; 
+using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; + +// bf8 +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using 
f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + +// pack int4 +using pk_i4x2_t = typename 
vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + +} // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 2533073225..799683ae65 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_utils.hpp" namespace ck { diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 98f40a4363..ab9cc4199c 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -3,6 +3,7 @@ #pragma once #include "data_type.hpp" +#include "dtype_fp64.hpp" namespace ck { diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 05ae9093e2..7b079c541c 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck/ck.hpp" -#include "data_type.hpp" +#include "numeric_limits.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index 757d3914e3..72a0bb919c 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -4,7 +4,7 @@ #ifndef CK_CODE_GEN_RTC #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck::utils { diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index 00b4f8e5d4..cf68188b3e 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -4,7 +4,7 @@ #ifndef CK_CODE_GEN_RTC #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck::utils { diff --git 
a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index 2dbf997f6a..b7b98c6455 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -1,4 +1,7 @@ -#include "ck/utility/data_type.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ diff --git a/include/ck/utility/numeric_limits.hpp b/include/ck/utility/numeric_limits.hpp new file mode 100644 index 0000000000..e59b7eceaf --- /dev/null +++ b/include/ck/utility/numeric_limits.hpp @@ -0,0 +1,555 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + + __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int8_t Lowest() noexcept { 
return -128; } + + __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + + __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } + + __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } + + __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned int binary_min = 0x00800000; + static constexpr unsigned int binary_max = 0x7F7FFFFF; + static constexpr unsigned int binary_lowest = 0xFF7FFFFF; + static constexpr unsigned int binary_qnan = 0xFFC00001; + static constexpr unsigned int binary_inf = 0x7F800000; + + __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } + + __host__ __device__ static constexpr float Infinity() { return 
bit_cast(binary_inf); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t 
QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 
0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float 
data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +#else +template 
+struct NumericLimits +{ + __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } + __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } + __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } + __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static 
constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return 
bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float 
DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t 
Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +#endif + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; + +} // namespace ck diff --git a/include/ck/utility/numeric_utils.hpp b/include/ck/utility/numeric_utils.hpp new file mode 100644 index 0000000000..726f667518 --- /dev/null +++ 
b/include/ck/utility/numeric_utils.hpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +template +struct NumericUtils +{ +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; + using bitwise_type = uint32_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; + using bitwise_type = uint16_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static 
constexpr int bias = 127; // ieee mode +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 8; // negative zero nan mode + // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 16; // negative zero nan mode + // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; +}; +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 7; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 15; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + 
static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; +} // namespace ck diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 25dae4e335..99935a6d8d 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -256,6 +256,18 @@ struct arithmetic_sequence_gen using 
type = typename conditional::type; }; +template +struct arithmetic_sequence_gen<0, IEnd, 1> +{ + template + struct WrapSequence + { + using type = Sequence; + }; + // https://reviews.llvm.org/D13786 + using type = typename __make_integer_seq::type; +}; + // uniform sequence template struct uniform_sequence_gen diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 376027ec98..d159787387 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -38,7 +38,6 @@ make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byt return [=](const stream_config& s) { kernel<<>>(args...); - return hipPeekAtLastError() == hipSuccess; }; } @@ -46,7 +45,7 @@ template CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... callables) { // abort the sequence in case of intermediate error - if(!(callables(sc) && ...)) + if(!((static_cast(callables(sc)), hipPeekAtLastError() == hipSuccess) && ...)) { HIP_CHECK_ERROR(hipGetLastError()); } diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 475d7014dd..070168b51d 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -35,11 +35,13 @@ template void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, const HostTensor& gamma_n, HostTensor& y_m_n, HostTensor& invRms_m, + HostTensor& unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor = {}) { @@ -69,7 +71,14 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, acc(m, n) = x * divisor * gamma; } - epilogue_functor(m, y_m_n, acc); + if constexpr(!std::is_same_v) + { + epilogue_functor(m, unquant_y_m_n, y_m_n, acc); + } + else + { + epilogue_functor(m, y_m_n, acc); + } }; make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp old mode 100644 new mode 100755 index 8592f93e0f..b422a0a896 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -18,6 +18,7 @@ template <> struct typeToStr { static constexpr const char * name = "bf1 template <> struct typeToStr { static constexpr const char * name = "fp8"; }; template <> struct typeToStr { static constexpr const char * name = "bf8"; }; template <> struct typeToStr { static constexpr const char * name = "int8"; }; +template <> struct typeToStr { static constexpr const char * name = "pk_int4"; }; // clang-format on template diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 9d2ed407c9..12e53e13e6 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp 
b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp new file mode 100644 index 0000000000..6c5a2ac149 --- /dev/null +++ b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "default_2d_epilogue.hpp" +#include "dynamic_quant_epilogue.hpp" + +namespace ck_tile { + +// User can reuse DynamicQuantEpilogueTraits with this epilogue +template +using Default2DAndDynamicQuantEpilogueTraits = + DynamicQuantEpilogueTraits; + +// This epilogue just store out a M*N matrix, row major +template +struct Default2DAndDynamicQuantEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using SmoothScaleDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; // can consum generic 2d shape + using Traits = remove_cvref_t; +}; + +template +struct Default2DAndDynamicQuantEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; + + static constexpr bool kPadM = Problem::Traits::kPadM; + static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool UseRawStore = Problem::Traits::UseRawStore; + + using Default2DProblem = + Default2DEpilogueProblem; + using Default2D = Default2DEpilogue; + using DynamicQuant = DynamicQuantEpilogue; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize()); + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp, + ODramWindowTmpQ& o_quant_dram_window_tmp, + const SmoothScaleWindow& sm_scale_window_, + YScaleWindow& y_scale_window, + const OAccTile& o_acc_tile, + void* smem) + { + Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem); + 
DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem); + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp, + ODramWindowTmpQ& o_quant_dram_window_tmp, + YScaleWindow& y_scale_window, + const OAccTile& o_acc_tile, + void* smem) + { + Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem); + DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 23174528e7..35b2f02e8a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -100,10 +100,10 @@ struct FmhaBwdDQDKDVKernel "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" + - ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + - (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "" ); + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? 
"_dropout" : "_ndropout" ) + + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ); #undef _SS_ #undef _TS_ // clang-format on @@ -1620,7 +1620,7 @@ struct FmhaBwdOGradDotOKernel return _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" + - ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn); #undef _SS_ #undef _TS_ // clang-format on @@ -1875,8 +1875,8 @@ struct FmhaBwdConvertQGradKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" + - ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn); + "_" + (kIsGroupMode ? "group" : "batch") + "_" + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) + + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ; #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index c671463252..a578f0c2f4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -93,9 +93,9 @@ struct FmhaFwdKernel "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? 
"_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index a342a91f10..99ee912db9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -54,9 +54,9 @@ struct FmhaFwdSplitKVCombineKernel "b" + _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + - (pn.empty() ? "" : "_" + pn) + - (kStoreLSE ? "_lse" : "" ) + - (kDoFp8StaticQuant ? "_squant" : "" ); + (pn.empty() ? "_npad" : "_" + pn) + + (kStoreLSE ? "_lse" : "_nlse" ) + + (kDoFp8StaticQuant ? 
"_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 14d0596287..143abe8048 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -94,9 +94,10 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? 
"_pagedkv" : "_npagedkv" ); #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 323c682f2c..dfb6bfae58 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -46,7 +46,7 @@ struct BatchedGemmKernel : public GemmKernel; - using GemmKernelArgs = typename Base::GemmKernelArgs; + using GemmKernelArgs = typename ck_tile::GemmKernelArgs; using ADataType = typename Base::ADataType; using BDataType = typename Base::BDataType; @@ -65,7 +65,7 @@ struct BatchedGemmKernel : public GemmKernel, - concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); // clang-format on diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 503a92b863..9435855d0a 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem index_t k_batch; }; +struct GemmKernelArgs +{ + const void* a_ptr; + const void* b_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t k_batch; +}; + template struct GemmKernel { @@ -90,20 +104,6 @@ struct GemmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - struct GemmKernelArgs - { - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; - index_t k_batch; - }; - CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) { return GemmKernelArgs{hostArgs.a_ptr, diff --git 
a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 751e7c0e1a..5577cb083a 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -11,24 +11,17 @@ namespace ck_tile { -struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs +struct GemmTransKernelArg { - CK_TILE_HOST GroupedGemmHostArgs() noexcept = default; - CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_) - : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_) + GemmKernelArgs group_karg; + ck_tile::index_t block_start; + ck_tile::index_t block_end; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) + : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - - private: - static constexpr index_t KBatch = 1; }; template @@ -47,36 +40,22 @@ struct GroupedGemmKernel : public GemmKernel; using Base = GemmKernel; - using GemmKernelArgs = typename Base::GemmKernelArgs; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - struct GemmTransKernelArg - { - GemmKernelArgs group_karg; - ck_tile::index_t block_start; - ck_tile::index_t block_end; - - GemmTransKernelArg() = default; - GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) - : group_karg{karg}, block_start{bl_start}, block_end{bl_end} - { - } - }; - [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off using P_ = GemmPipeline; return concat('_', "gemm_grouped", gemm_prec_str, - concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), 
P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); // clang-format on } - __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); @@ -84,7 +63,7 @@ struct GroupedGemmKernel : public GemmKernel dim3 { return dim3(KernelBlockSize); } - __host__ static constexpr auto GridSize(const std::vector& gemm_descs) + __host__ static constexpr auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -95,7 +74,7 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 2a10389ce6..217408fffa 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -12,7 +12,7 @@ namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template +template struct GemmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; @@ -182,11 +182,11 @@ struct GemmPipelineAGmemBGmemCRegV1 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(std::is_same_v) + if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp); 
} @@ -196,11 +196,11 @@ struct GemmPipelineAGmemBGmemCRegV1 } // LDS write 0 - if constexpr(std::is_same_v) + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDistribution()); - shuffle_tile(b_shuffle_tmp, b_block_tile); + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp); } @@ -229,15 +229,26 @@ struct GemmPipelineAGmemBGmemCRegV1 move_tile_window(b_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp_loop, a_block_tile); + store_tile(a_copy_lds_window, + tile_elementwise_in(a_element_func, a_shuffle_tmp_loop)); + } + else + { + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } // LDS write i + 1 - if constexpr(std::is_same_v) + if constexpr(is_b_row_major) { auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDistribution()); - shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp_loop, b_block_tile); store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp old mode 100644 new mode 100755 index c7115c8eb4..6bb14af9e6 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -129,7 +129,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t KPack = GetSmemPackA(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * M0)) + if constexpr(get_warp_size() >= (K2 * M0)) { constexpr index_t K1 = get_warp_size() / (K2 * M0); constexpr index_t K0 = BlockSize / get_warp_size(); @@ -219,7 +219,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t KPack = GetSmemPackB(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(get_warp_size() >= (K2 * N0)) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = BlockSize / get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index f5b3523f60..c504a51ad0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -362,7 +362,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), + make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -374,7 +374,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), 
make_tuple(sequence<0>{}, sequence<1>{})); return a_lds_block_desc; @@ -421,7 +421,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0, number{})), + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -432,7 +432,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(make_merge_transform_v3_division_mod( make_tuple(number{}, number{})), make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index b0b0c194ad..73cdd084c6 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -182,9 +182,16 @@ struct Layernorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? 
row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + move_tile_window(x_bias_window, {-Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); @@ -192,28 +199,43 @@ struct Layernorm2dFwdPipelineTwoPass // layernorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - const auto x_bias = load_tile(x_bias_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x, [&](auto idx) { - // compute x = bias + x - constexpr auto j_idx = make_tuple(idx[number<1>{}]); - acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); - }); + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + { + const auto x_bias = load_tile(x_bias_window); + move_tile_window(x_bias_window, {-Block_N}); + + sweep_tile(acc, [&](auto idx) { + // compute x = bias + x + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); + }); + } + + if constexpr(kFusedAdd == 
Layernorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + move_tile_window(x_residual_window, {0, -Block_N}); + + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + } } - if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) - { - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); - } // load gamma/beta (TODO: support no gamma/beta?) const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); @@ -235,9 +257,6 @@ struct Layernorm2dFwdPipelineTwoPass static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); Epilogue{}(y_window, ln); - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 88c8084de6..f0251177d4 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -21,6 +21,7 @@ struct Rmsnorm2dFwdHostArgs void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used + void* p_y_unquant; // [m, n], output result before quant, nullptr if not used float epsilon; @@ -47,13 +48,15 @@ struct Rmsnorm2dFwd using InvRmsDataType = remove_cvref_t; using SmoothScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; // for 
simplicity, shortcut input/output type is same as X using XResidualDataType = XDataType; using YResidualDataType = XDataType; - static constexpr bool kHasGamma = !std::is_same_v; - static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_N = Problem::BlockShape::Block_N; @@ -81,6 +84,7 @@ struct Rmsnorm2dFwd void* p_y_residual; void* p_y_scale; void* p_invRms; + void* p_y_unquant; float epsilon; @@ -103,6 +107,7 @@ struct Rmsnorm2dFwd hargs.p_y_residual, hargs.p_y_scale, hargs.p_invRms, + hargs.p_y_unquant, hargs.epsilon, hargs.m, hargs.n, @@ -323,6 +328,30 @@ struct Rmsnorm2dFwd } }(); + auto unquant_y_window = [&]() { + if constexpr((kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT || + kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) && + kSaveUnquant) + { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y_unquant), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.y_stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + { + return make_null_tile_window(make_tuple(number{}, number{})); + } + }(); + __shared__ char smem[GetSmemSize()]; Pipeline{}(x_window, @@ -333,6 +362,7 @@ struct Rmsnorm2dFwd inv_rms_window, sm_scale_window, y_scale_window, + unquant_y_window, static_cast(kargs.epsilon), kargs.n, smem, diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 93c2833be4..58159142d0 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp 
+++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -25,8 +25,9 @@ struct Rmsnorm2dFwdPipelineOnePass using XResidualDataType = XDataType; using YResidualDataType = XDataType; - static constexpr bool kHasGamma = !std::is_same_v; - static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; + static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM @@ -54,6 +55,7 @@ struct Rmsnorm2dFwdPipelineOnePass typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, + typename UnquantYWindow, typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, const XResidualWindow& x_residual_window_, @@ -63,6 +65,7 @@ struct Rmsnorm2dFwdPipelineOnePass InvRmsWindow& inv_rms_window, const SmoothScaleWindow& sm_scale_window_, YScaleWindow& y_scale_window_, + UnquantYWindow& unquant_y_window, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem, @@ -137,11 +140,26 @@ struct Rmsnorm2dFwdPipelineOnePass if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { - Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + if constexpr(kSaveUnquant) + { + Epilogue{}( + unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem); + } } else if constexpr(kFusedQuant == Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT) { - Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + if constexpr(kSaveUnquant) + { + Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem); + } + else + { + Epilogue{}(y_window_, y_scale_window_, rmsn, smem); + } } else { diff --git 
a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp index baf56246f3..773df4f0f4 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -12,6 +12,7 @@ template ; using YDataType = remove_cvref_t; using InvRmsDataType = remove_cvref_t; + using UnquantYDataType = remove_cvref_t; using SmoothScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; using BlockShape = remove_cvref_t; diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index c29a6cb07d..4ca1dbc5da 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -54,6 +54,7 @@ struct Rmsnorm2dFwdPipelineTwoPass typename InvRmsWindow, typename SmoothScaleWindow, typename YScaleWindow, + typename UnquantYWindow, typename Epilogue> CK_TILE_DEVICE auto operator()(const XWindow& x_window_, const XResidualWindow& x_residual_window_, @@ -63,6 +64,7 @@ struct Rmsnorm2dFwdPipelineTwoPass InvRmsWindow& inv_rms_window, const SmoothScaleWindow& /*sm_scale_window_*/, YScaleWindow& /*y_scale_window*/, + UnquantYWindow& /*unquant_y_window*/, ComputeDataType epsilon, ck_tile::index_t row_size, void* smem, @@ -136,32 +138,51 @@ struct Rmsnorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? 
row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); // rmsnorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + move_tile_window(x_residual_window, {0, -Block_N}); + } } // load gamma (TODO: support no gamma?) 
const auto gamma = load_tile(gamma_window); // rmsnorm computation - auto rmsn = make_static_distributed_tensor(x.get_tile_distribution()); + auto rmsn = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -176,8 +197,6 @@ struct Rmsnorm2dFwdPipelineTwoPass static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); Epilogue{}(y_window, rmsn); - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); } diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp index cb7beba291..152da60c01 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp @@ -39,6 +39,7 @@ template<> struct Rmsnorm2dFusedQuantEnumName @@ -46,6 +47,7 @@ struct Rmsnorm2dFwdTraits { static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kSaveUnquant = kSaveUnquant_; static constexpr bool kTwoPass = kTwoPass_; static constexpr Rmsnorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index bea22da2c2..1c4dc8a445 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -64,6 +64,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 
1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, @@ -129,6 +130,7 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 
1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index d317d270ce..0a85cde3bc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -46,6 +46,7 @@ using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< // generic instance DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -65,6 +66,7 @@ using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< // generic instance DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 
2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index f261164d61..74338ba383 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -3,7 +3,7 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) { constexpr int M = 256; - constexpr int N = 128; - constexpr int K = 128; + constexpr int N = 256; + constexpr int K = 512; this->Run(M, N, K); } diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 0f787b718d..0af3ef3b34 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -28,17 +28,9 @@ class TestCkTileBatchedGemm : public ::testing::Test void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. 
- constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -46,72 +38,144 @@ class TestCkTileBatchedGemm : public ::testing::Test constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; - using CodegenGemmShape = + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber 
tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + float ave_time{0}; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = - ck_tile::BatchedGemmKernel; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - auto kargs = Kernel::MakeKernelArgs(args); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - if(s.log_level_ > 0) + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." 
+ << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: void Run(const int M, const int N, const int K, - int StrideA = 128, - int StrideB = 128, - int StrideC = 128, - const int BatchStrideA = 32768, - const int BatchStrideB = 16384, - const int BatchStrideC = 32768, - const int BatchCount = 16) + int StrideA = 512, + int StrideB = 512, + int StrideC = 256, + const int BatchStrideA = 131072, + const int BatchStrideB = 131072, + const int BatchStrideC = 65536, + const int BatchCount = 8) { using namespace ck_tile::literals; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc index 68c4693bb3..9f6b66c92b 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -2,7 +2,7 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic) { - const int group_count = 16; + const int group_count = 8; std::vector Ms; std::vector Ns; std::vector Ks; @@ -13,8 +13,8 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic) for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(128 + 128 * i); - Ks.push_back(128 + 64 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(256 + 64 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index cd94d0b867..b125d19762 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -44,65 +44,10 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 8; }; - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - 
ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - template - using CodegenGemmTraits = ck_tile::TileGemmTraits; - - template - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem>; - - template - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1>; - - template - using GemmEpilogue = ck_tile::CShuffleEpilogue::BlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GroupedGemKernelParam::M_Warp, - GroupedGemKernelParam::N_Warp, - GroupedGemKernelParam::M_Warp_Tile, - GroupedGemKernelParam::N_Warp_Tile, - GroupedGemKernelParam::K_Warp_Tile, - CodegenPipelineProblem::TransposeC>>; - - template - using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; - - using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; - std::size_t GetWorkspaceSize(const std::vector& gemm_descs) + using grouped_gemm_kargs = ck_tile::GemmHostArgs; + std::size_t get_workspace_size(const std::vector& gemm_descs) { - return Kernel::GetWorkSpaceSize(gemm_descs); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template @@ -110,35 +55,140 @@ class TestCkTileGroupedGemm : public ::testing::Test const ck_tile::stream_config& s, void* p_workspace_) { - using GroupedGemmKernel = Kernel; + constexpr bool DoubleSmemBuffer = false; + constexpr bool TransposeC = false; - auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs); + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; - const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs); - constexpr dim3 blocks = GroupedGemmKernel::BlockSize(); + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - ck_tile::hip_check_error(hipMemcpyWithStream( - p_workspace_, - arguments.data(), - arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg), - hipMemcpyHostToDevice, - 
s.stream_id_)); + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - if(s.log_level_ > 0) + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + 
ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - GroupedGemmKernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); } public: @@ -243,12 +293,14 @@ class TestCkTileGroupedGemm : public ::testing::Test const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + // TODO add support for kbatch > 1 + static constexpr ck_tile::index_t k_batch = 1; gemm_descs.push_back( - {p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; - gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs)); + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); invoke_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());