From 705804d9bf87e1e2fca23c0af231efcdebf76efb Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Sat, 30 Aug 2025 09:54:18 -0400 Subject: [PATCH 001/404] Restructure the Tile Engine to have faster build time and clear config report (#2747) * Making edits to identify individual compilation issues. * Minor fix for blob txt files not being created. * Fixing compilation issues. * Fixing ordering bug. * Adding python profiling functionality. * Setting individual build as default. * Setting gpu target filtering for tile engine to gfx90a, gfx942 and gfx950. * update the default running parameters and settings * Fixing bug with benchmarking, shifting file generation to build instead of config. * Updating fixes. * Fixing json output and parsing. * Disable ccache for tile engine gemm ops because we dont need it. * Removing duplicate type definition. * Improving json printing. * Add the flexibility of different layout and more warp tile support * Fix extra flag in name of individual kernels. * Fixing bug with booleans. * Solve the first patch of the post merge conflict * Compilation fixes, and cosmetic improvements. * Yet again compilation fixes after latest changes from develop. * Fixing python benchmarking script. --------- Co-authored-by: Vidyasagar Ananthan Co-authored-by: Vidyasagar Ananthan --- script/cmake-ck-dev.sh | 13 +- tile_engine/ops/gemm/CMakeLists.txt | 418 +++-- tile_engine/ops/gemm/README.md | 495 ++++- tile_engine/ops/gemm/benchmark_gemm.cpp | 68 - tile_engine/ops/gemm/benchmark_gemm.hpp | 19 +- .../ops/gemm/benchmark_gemm_single.cpp | 160 ++ tile_engine/ops/gemm/codegen_utils.py | 8 + tile_engine/ops/gemm/configs/benchmark.json | 12 +- .../ops/gemm/configs/default_config.json | 200 +- tile_engine/ops/gemm/gemm_benchmark.py | 721 ++++++++ tile_engine/ops/gemm/gemm_common.hpp | 197 ++ tile_engine/ops/gemm/gemm_host_api.hpp | 223 --- tile_engine/ops/gemm/gemm_instance_builder.py | 1612 +++++++++-------- tile_engine/ops/gemm/gemm_profiler.hpp | 37 +- tile_engine/ops/gemm/test_benchmark.sh | 102 ++ tile_engine/ops/gemm/test_validation.py | 143 ++ tile_engine/ops/gemm/validation_utils.py | 342 ++++ 17 files changed, 3361 insertions(+), 1409 deletions(-) delete mode 100644 tile_engine/ops/gemm/benchmark_gemm.cpp create mode 100644 tile_engine/ops/gemm/benchmark_gemm_single.cpp create mode 100755 tile_engine/ops/gemm/gemm_benchmark.py create mode 100644 tile_engine/ops/gemm/gemm_common.hpp delete mode 100644 tile_engine/ops/gemm/gemm_host_api.hpp mode change 100755 => 100644 tile_engine/ops/gemm/gemm_instance_builder.py create mode 100755 tile_engine/ops/gemm/test_benchmark.sh create mode 100644 tile_engine/ops/gemm/test_validation.py create mode 100644 tile_engine/ops/gemm/validation_utils.py diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index b93555901e..217ec998bd 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -25,13 +25,20 @@ if [ $# -ge 1 ]; then GPU_TARGETS=$1 shift 1 echo "GPU targets provided: $GPU_TARGETS" + REST_ARGS=("$@") ;; *) - echo "No GPU targets provided, using default targets: $GPU_TARGETS" + echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" + GPU_TARGETS="gfx908;gfx90a;gfx942" + shift 1 + REST_ARGS=("$@") ;; esac else - echo "No GPU targets provided, using default targets: $GPU_TARGETS" + echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" + GPU_TARGETS="gfx908;gfx90a;gfx942" + shift 1 + REST_ARGS=("$@") fi cmake \ @@ -43,5 +50,5 @@ cmake -D GPU_TARGETS=$GPU_TARGETS \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ -$@ \ +"${REST_ARGS[@]}" \ \ ${MY_PROJECT_SOURCE} diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index 42c114b499..d52351af2d 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,169 +1,295 @@ - set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") +set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) -function(build_gemm_for_datatype datatype layout) - # Filter GPU targets to only gfx90a, gfx942, and gfx950 - set(GEMM_GPU_TARGETS "") - set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") - - foreach(target IN LISTS SUPPORTED_GPU_TARGETS) - if(target IN_LIST DESIRED_TARGETS) - list(APPEND GEMM_GPU_TARGETS ${target}) - endif() - endforeach() - - # Skip compilation if no matching targets found - if(NOT GEMM_GPU_TARGETS) - message(WARNING "Skipping Tile Engine GEMM compilation: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +# Store the directory path for use in functions +set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM targets +function(create_individual_gemm_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") return() endif() - message(STATUS "Building GEMM for GPU targets: ${GEMM_GPU_TARGETS}") + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - - # Comment this if-else block when using user_provided_config - if(layout STREQUAL "rcr") - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") - else() - set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json") - endif() - - # uncomment this if you want to use user_provided_config.json - # set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json") - # Generate kernel list + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + ${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_all ${target_name}) + add_dependencies(benchmark_gemm_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_${layout} ${target_name}) + add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_${pipeline} ${target_name}) + add_dependencies(benchmark_gemm_${epilogue} ${target_name}) + add_dependencies(benchmark_gemm_${scheduler} ${target_name}) +endfunction() + +# Function to build individual GEMM targets +function(build_individual_gemm_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_CONFIG_FILE + # 2. CMake variable GEMM_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(STATUS " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") + message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(STATUS " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(STATUS " Working path: ${working_path}") + message(STATUS " Config file: ${json_blob}") + message(STATUS " Python executable: ${Python3_EXECUTABLE}") + message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + # First, just list the kernels (fast operation) + message(STATUS " Listing kernel configurations...") execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${working_path} --datatype ${datatype} --layout ${layout} --config_json ${json_blob} - --list_blobs + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error ) + if(NOT ret EQUAL 0) - message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}") + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") endif() - - file(STRINGS "${working_path}/gemm_instance_blobs.txt" codegen_blobs) - file(STRINGS "${working_path}/gemm_instance_blobs_range.txt" codegen_blobs_range) - # Generate the blobs - add_custom_command( - OUTPUT ${codegen_blobs} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - --working_path "${working_path}" - --datatype ${datatype} - --layout ${layout} - --config_json "${json_blob}" - --gen_blobs - COMMENT "Generating GEMM instance sources for ${datatype} ${layout}" - ) - add_custom_target(gemm_gen_${datatype}_${layout} DEPENDS ${codegen_blobs}) - - set(intermediate_libs) - list(LENGTH codegen_blobs codegen_blobs_len) - - foreach(blob IN LISTS codegen_blobs_range) - string(STRIP "${blob}" stripped_blob) - separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}") - # Each line is: - list(GET spilit_blob 0 name) - list(GET spilit_blob 1 first) - list(GET spilit_blob 2 last) - math(EXPR total_files "${last} - ${first}") - if(total_files EQUAL 0) - continue() # nothing for this trait - endif() - - # Object libraries (chunked) per trait - set(sub_intermediate_libs) - set(chunk_size 3) - math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}") - math(EXPR num_chunks_minus_1 "${num_chunks} - 1") - - foreach(i RANGE 0 ${num_chunks_minus_1}) - math(EXPR start "${first} + ${i} * ${chunk_size} ") - math(EXPR end "${start} + ${chunk_size} - 1") - - set(chunk_files) - foreach(j RANGE ${start} ${end}) - if(j LESS ${last} AND j LESS ${codegen_blobs_len}) - list(GET codegen_blobs ${j} f) - list(APPEND chunk_files "${f}") - endif() - endforeach() - - #list(LENGTH chunk_files chunk_files_len) - #if(chunk_files_len AND chunk_files_len GREATER 1) - if(chunk_files) - set(sub_intermediate_lib_name "gemm_objlib_${name}_${i}_${datatype}_${layout}") - add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files}) - set_property(TARGET ${sub_intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name}) - endif() - + # Read kernel count + if(EXISTS ${working_path}/gemm_kernel_count.txt) + file(READ ${working_path}/gemm_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_kernel_list.txt) + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") endforeach() - - # ------------------ Bundle the object libs into one static lib --------- - #list(LENGTH sub_intermediate_libs sub_intermediate_libs_len) - #if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1) - if(sub_intermediate_libs) - set(intermediate_lib_name "gemm_staticlib_${name}_${datatype}_${layout}") - # Collect the $ expressions - - set(obj_exprs) - foreach(objlib IN LISTS sub_intermediate_libs) - list(APPEND obj_exprs $) - endforeach() - - add_library(${intermediate_lib_name} STATIC ${obj_exprs}) - add_dependencies(${intermediate_lib_name} gemm_gen_${datatype}_${layout}) - set_property(TARGET ${intermediate_lib_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - #foreach(objlib IN LISTS sub_intermediate_libs) - # target_sources(${intermediate_lib_name} PRIVATE $) - #endforeach() - list(APPEND intermediate_libs ${intermediate_lib_name}) - endif() - - endforeach() - - # Interface library for instances - add_library(gemm_template_instances_${datatype}_${layout} INTERFACE) - add_dependencies(gemm_template_instances_${datatype}_${layout} gemm_gen_${datatype}_${layout}) - target_link_libraries(gemm_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs}) - target_include_directories(gemm_template_instances_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - set_target_properties(gemm_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX) - - # Host API interface library - add_library(gemm_host_api_${datatype}_${layout} INTERFACE) - target_link_libraries(gemm_host_api_${datatype}_${layout} INTERFACE gemm_template_instances_${datatype}_${layout}) - target_include_directories(gemm_host_api_${datatype}_${layout} INTERFACE - ${CMAKE_CURRENT_LIST_DIR} - "${working_path}" - ) - - - # Executable per datatype - set(exec_name "benchmark_gemm_${datatype}_${layout}") - add_executable(${exec_name} benchmark_gemm.cpp) - set_property(TARGET ${exec_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS}) - target_link_libraries(${exec_name} PRIVATE gemm_host_api_${datatype}_${layout}) - target_compile_options(${exec_name} PRIVATE - -Wno-undefined-func-template - -Wno-float-equal - --offload-compress - ) + else() + message(FATAL_ERROR "Kernel list file not found") + endif() endfunction() -# Process each datatype in isolation -foreach(dt IN LISTS GEMM_DATATYPE) - foreach(l IN LISTS GEMM_LAYOUT) - build_gemm_for_datatype(${dt} ${l}) - endforeach() +# Main build logic - Only individual builds supported +message(STATUS "=== Starting Tile Engine GEMM Configuration ===") +message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}") +message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942, and gfx950 +set(GEMM_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target}) + message(STATUS " Adding GPU target: ${target}") + endif() endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_DATATYPE) + add_custom_target(benchmark_gemm_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_LAYOUT) + add_custom_target(benchmark_gemm_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_DATATYPE) + foreach(l IN LISTS GEMM_LAYOUT) + add_custom_target(benchmark_gemm_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM kernels + set(GEMM_PIPELINES "mem;compv3;compv4") + set(GEMM_EPILOGUES "default;cshuffle") + set(GEMM_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_PIPELINES) + add_custom_target(benchmark_gemm_${pipeline}) + endforeach() + + foreach(epilogue IN LISTS GEMM_EPILOGUES) + add_custom_target(benchmark_gemm_${epilogue}) + endforeach() + + foreach(scheduler IN LISTS GEMM_SCHEDULERS) + add_custom_target(benchmark_gemm_${scheduler}) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_DATATYPE) + foreach(l IN LISTS GEMM_LAYOUT) + build_individual_gemm_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 79152a1a0d..01ffbb6da7 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -1,113 +1,442 @@ -# GEMM Matrix Multiplication +# CK Tile Engine GEMM Operations -CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues. +## Overview -# Kernel Configurations +The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities. -Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time. -For reference please see `./configs/user_provided_config.json`. +## Table of Contents +1. [Build System Architecture](#build-system-architecture) +2. [Build Instructions](#build-instructions) +3. [Running Benchmarks](#running-benchmarks) +4. [Configuration System](#configuration-system) +5. [Scripts and Tools](#scripts-and-tools) +6. [Command Line Options](#command-line-options) +7. [Understanding Kernel Names](#understanding-kernel-names) +8. [Troubleshooting](#troubleshooting) +9. [Performance Tips](#performance-tips) -The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` +## Build System Architecture -If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. +### Individual Kernel Compilation (New Approach) + +The new tile engine benchmark system compiles each kernel configuration into a separate executable. This provides: +- Better build parallelism +- Faster incremental builds +- More targeted testing +- Easier debugging of specific configurations + +Each benchmark executable follows the naming pattern: +``` +benchmark_gemm____ +``` + +### Monolithic Build (Legacy Approach) + +The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments. ## Build Instructions -``` bash -# in the root of composable kernel create build directory + +### Prerequisites +- ROCm installation +- CMake 3.16 or higher +- C++17 compatible compiler + +### Basic Build + +```bash +# In the root of composable kernel, create build directory mkdir build && cd build -# build composable kernel -# replace [Arch] with the appropriate architecture or leave blank and -# replace [Datatype1;Datatype2;...] in comma separated datatypes string (possible datatypes are [fp8, bf8, int8, fp16, bf16]) -# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr]) -../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" -# generate different executable for each passed datatype + +# Configure with specific datatypes and layouts +# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942) +# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64) +# Replace [Layout1;Layout2;...] with layouts (rcr, rrr, crr, ccr) +../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]" + +# Build specific benchmarks make benchmark_gemm_[Datatype1]_[Layout1] -j -make benchmark_gemm_[Datatype1]_[Layout2] -j -make benchmark_gemm_[Datatype2]_[Layout1] -j -make benchmark_gemm_[Datatype2]_[Layout2] -j -``` -`benchmark_gemm_[Datatype]_[Layout]` will be located in the `./bin/` directory. - -`benchmark_gemm_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified. - -``` bash -rm -rf tile_engine/ && make benchmark_gemm_[Datatypes]_[Layout] -j # rebuild ``` -## For eaxmple build for gfx942 for fp8 and fp16 datatypes with rcr layout -``` bash +### Configuration Options + +The build system supports several configuration options: + +#### Using Custom Config Files +```bash +# Method 1: CMake variable (config file must be in configs/ directory) +cmake -DGEMM_CONFIG_FILE=my_custom_config.json ... + +# Method 2: Environment variable (takes precedence over CMake variable) +export GEMM_CONFIG_FILE=my_custom_config.json +cmake ... +``` + +#### Config File Priority Order +1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority) +2. **CMake variable** `GEMM_CONFIG_FILE` +3. **Default config** (default_config.json for all layouts) + +**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory. + +### Example Build Commands + +```bash +# Build for gfx942 with fp8 and fp16 datatypes, rcr layout mkdir build && cd build -../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr" +../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr" make benchmark_gemm_fp8_rcr -j make benchmark_gemm_fp16_rcr -j ``` -## benchmark_gemm inputs +### Building Individual Kernels + +```bash +# Build a specific kernel configuration +make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32 + +# Build all fp16 benchmarks in parallel +make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}') ``` - -m The value for m dimension. Default is 3840. - -n The value for n dimension. Default is 4096. - -k The value for k dimension. Default is 2048. - -stride_a The stride value for tensor A. Default is 0. - -stride_b The stride value for tensor B. Default is 0. - -stride_c The stride value for tensor C Default is 0. - -split_k The split value for k dimension. Default is 1. - -v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU. - -log Wether output kernel instance information or not. Possible values are true or false. Default is false. - -warmup The number of iterations before benchmark the kernel. Default is 50. - -repeat The number of iterations to benchmark the kernel. Default is 100. - -timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true. - -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. - -flush_cache To flush cache in between different runs.Possible values are true or false. Default is false. - -rotating_count count to flush cache. Default is 5. - -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. - -csv_filename The filename of benchmark result. Default is gemm_kernel. - -structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false. - -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. - -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. - -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. - -pad_n Whether pad or not in n direction. Possible values are true or false. Default is false. - -pad_k Whether pad or not in k direction. Possible values are true or false. Default is false. -Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json +### Rebuilding After Configuration Changes + +If you modify the configuration file, you must rebuild: +```bash +rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j ``` -Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. -## Example +## Running Benchmarks -The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. +### Individual Kernel Execution + +```bash +cd /path/to/build/directory +./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \ + -m=512 -n=512 -k=512 -verify=1 +``` + +### Monolithic Executable (Legacy) + +```bash +# Run specific pipeline/scheduler/epilogue combination +./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default +``` + +### Automated Testing + +Use the provided test script to run multiple benchmarks: +```bash +cd /path/to/composable_kernel/tile_engine/ops/gemm +./test_benchmark.sh [build_directory] +``` + +## Configuration System + +### Configuration Files + +The system uses JSON configuration files to specify kernel parameters: + +- `configs/default_config.json` - Default configurations for various datatypes +- `configs/user_provided_config.json` - User-customizable configurations + +### Configuration Structure ```json -{ - /// other parameters /// - - "tile_m": { - "values": [256] +{ + "tile_config": { + "tile_m": {"values": [256, 128]}, + "tile_n": {"values": [256, 128]}, + "tile_k": {"values": [64, 32]}, + "warp_m": {"values": [2, 4]}, + "warp_n": {"values": [2, 1]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32, 16]}, + "warp_tile_n": {"values": [32, 16]}, + "warp_tile_k": {"values": [16, 32]} }, - "tile_n": { - "values": [256] - }, - "tile_k": { - "values": [64, 32] - }, - - /// other parameters /// - - "pipeline": { - "values": ["compv3", "compv4", "mem"] - }, - "scheduler": { - "values": ["intrawave", "interwave"] - }, - "epilogue": { - "values": ["default", "cshuffle"] + "trait_config": { + "pipeline": {"values": ["compv3", "compv4", "mem"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} } } ``` -At runtime, a specific subset of the generated kernels can be selected using command-line arguments. -``` bash -./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default -``` -The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. +## Scripts and Tools +### Python Scripts + +#### gemm_instance_builder.py +**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files. + +**Key Features**: +- Generates individual kernel header files for separate compilation +- Supports multiple data types (fp16, fp8, bf16, fp32, fp64) +- Validates tile configurations for correctness +- Creates CMake integration files + +**Usage**: +```bash +python gemm_instance_builder.py \ + --working_path ./generated \ + --datatype fp16 \ + --layout rcr \ + --config_json configs/user_provided_config.json \ + --gen_individual +``` + +#### gemm_instance_builder_parallel.py +**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations. + +**Features**: +- Multi-threaded kernel generation +- Improved performance for large configuration spaces + +#### validation_utils.py +**Purpose**: Provides comprehensive validation functions for kernel configurations. + +**Key Functions**: +- `is_tile_config_valid()` - Validates tile dimensions and alignments +- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported +- `validate_warp_tile_combination()` - GPU-specific warp tile validation +- `validate_lds_capacity()` - Ensures configurations fit in LDS memory + +**Validation Checks**: +- Dimension alignment (tile dimensions must be divisible by warp dimensions) +- LDS capacity constraints +- GPU-specific warp tile support +- Unsupported trait combinations + +#### test_validation.py +**Purpose**: Test suite for the validation logic to ensure correctness. + +**Usage**: +```bash +python test_validation.py +``` + +**Tests**: +- Warp tile combination validation +- Trait combination validation +- Full tile configuration validation + +#### gemm_benchmark.py +**Purpose**: Python script for running and analyzing GEMM benchmarks. + +**Features**: +- Automated benchmark execution +- Performance data collection +- Result analysis and reporting + +#### json_config.py +**Purpose**: Configuration file parsing and management. + +**Features**: +- JSON configuration loading +- Default configuration handling +- Configuration validation + +#### codegen_utils.py +**Purpose**: Utility functions for code generation. + +**Features**: +- Template processing +- Code formatting utilities +- File generation helpers + +### Shell Scripts + +#### test_benchmark.sh +**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables. + +**Features**: +- Automatic build directory detection +- Batch execution of multiple benchmarks +- CSV result collection +- Colored output for easy reading +- Example command generation + +**Usage**: +```bash +# Auto-detect build directory +./test_benchmark.sh + +# Specify build directory +./test_benchmark.sh /path/to/build/directory +``` + +**What it does**: +1. Finds all benchmark executables in the build directory +2. Runs each with multiple problem sizes (512, 1024, 2048) +3. Performs GPU verification +4. Saves results to timestamped CSV file +5. Provides summary statistics + +## Command Line Options + +All benchmark executables support the following options: + +### Matrix Dimensions +- `-m=` - M dimension (default: 3840) +- `-n=` - N dimension (default: 4096) +- `-k=` - K dimension (default: 2048) + +### Strides +- `-stride_a=` - Stride for matrix A (default: 0, auto-calculated) +- `-stride_b=` - Stride for matrix B (default: 0, auto-calculated) +- `-stride_c=` - Stride for matrix C (default: 0, auto-calculated) + +### Verification +- `-verify=<0|1|2>` - Verification mode + - 0: No verification (default) + - 1: CPU verification + - 2: GPU verification + +### Performance Testing +- `-warmup=` - Warmup iterations (default: 50) +- `-repeat=` - Benchmark iterations (default: 100) +- `-timer=` - Use GPU timer (default: true) +- `-flush_cache=` - Flush cache between runs (default: true) +- `-rotating_count=` - Cache rotation count (default: 1000) + +### Initialization +- `-init=<0|1|2>` - Tensor initialization method + - 0: Random values [-1, 1] (default) + - 1: Linear sequence (i % 17) + - 2: Constant value (1.0) + +### Output Options +- `-log=` - Enable verbose logging (default: false) +- `-metric=<0|1|2>` - Performance metric + - 0: Latency in ms (default) + - 1: TFLOPS + - 2: Bandwidth in GB/s +- `-json_output=` - JSON format output (default: false) +- `-csv_filename=` - Save results to CSV +- `-csv_format=` - CSV format (default: comprehensive) + +### Advanced Options +- `-split_k=` - Split-K factor (default: 1) +- `-structured_sparsity=` - Enable structured sparsity (default: false) +- `-pipeline=` - Pipeline type (default: compv3) +- `-scheduler=` - Scheduler type (default: intrawave) +- `-epilogue=` - Epilogue type (default: cshuffle) +- `-pad_m=` - Pad M dimension (default: false) +- `-pad_n=` - Pad N dimension (default: false) +- `-pad_k=` - Pad K dimension (default: false) +- `-persistent=` - Use persistent kernel (default: false) + +## Understanding Kernel Names + +The kernel naming convention encodes the configuration: + +``` +benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 + ^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^ ^^^^^^^ ^^^^^^^^^ + | | | | | | | | | + | | | | | Padding & flags | | Warp tile + | | | | Scheduler | Thread tile + | | | Epilogue Block tile + | | Pipeline + | Layout (Row-Column-Row) + Data type +``` + +### Components: +- **Data type**: fp16, fp32, bf16, fp8, bf8, int8 +- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr +- **Pipeline**: mem, compv3, compv4 +- **Epilogue**: default, cshuffle +- **Scheduler**: intrawave, interwave +- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags) +- **Tile sizes**: BlockTile x ThreadTile x WarpTile + +## Troubleshooting + +### Common Issues + +1. **Kernel not found** + - Ensure the specific benchmark executable is built + - Check the build directory bin/ folder + +2. **Verification failures** + - Try GPU verification (-verify=2) which may be more accurate + - Check data type compatibility + - Verify stride calculations + +3. **Build failures** + - Check GPU architecture compatibility + - Ensure ROCm is properly installed + - Verify configuration file syntax + +4. **Performance variations** + - Increase warmup iterations + - Disable CPU frequency scaling + - Use GPU timer for accurate measurements + +### Debug Options + +Enable verbose logging: +```bash +./bin/benchmark_gemm_... -log=true -verify=1 +``` + +Test validation logic: +```bash +python test_validation.py +``` + +## Performance Tips + +1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions +2. **Warmup**: Use at least 50-100 warmup iterations +3. **GPU Timer**: Always use `-timer=true` for accurate measurements +4. **Cache Management**: Enable cache flushing for consistent results +5. **Thread Affinity**: Set CPU affinity to reduce variation + +## Integration Examples + +### Python Integration + +```python +import subprocess +import json + +# Run benchmark with JSON output +result = subprocess.run([ + './bin/benchmark_gemm_fp16_rcr_...', + '-m=1024', '-n=1024', '-k=1024', + '-json_output=true' +], capture_output=True, text=True) + +# Parse results +data = json.loads(result.stdout) +print(f"Performance: {data['tflops']} TFLOPS") +``` + +### Batch Testing Script + +```bash +#!/bin/bash +SIZES="512 1024 2048 4096" +for size in $SIZES; do + echo "Testing ${size}x${size}x${size}" + ./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \ + -verify=2 -csv_filename=results.csv +done +``` + +## Contributing + +When adding new features or configurations: +1. Update validation logic in `validation_utils.py` +2. Add tests to `test_validation.py` +3. Update configuration examples +4. Document new command-line options + +For more information about the Composable Kernel project, visit the main repository documentation. diff --git a/tile_engine/ops/gemm/benchmark_gemm.cpp b/tile_engine/ops/gemm/benchmark_gemm.cpp deleted file mode 100644 index db2b648437..0000000000 --- a/tile_engine/ops/gemm/benchmark_gemm.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include - -#include "gemm_profiler.hpp" -#include "benchmark_gemm.hpp" - -void benchmark_gemm(const ck_tile::ArgParser& arg_parser) -{ - GemmProblem gemm_problem{arg_parser.get_int("split_k"), - arg_parser.get_int("m"), - arg_parser.get_int("n"), - arg_parser.get_int("k"), - arg_parser.get_int("stride_a"), - arg_parser.get_int("stride_b"), - arg_parser.get_int("stride_c"), - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - DataTypeTraits::name, - ALayout::name, - BLayout::name, - CLayout::name, - arg_parser.get_bool("structured_sparsity")}; - - Setting setting{arg_parser.get_int("warmup"), - arg_parser.get_int("repeat"), - arg_parser.get_bool("timer"), - arg_parser.get_int("verify"), - arg_parser.get_int("init"), - arg_parser.get_bool("log"), - arg_parser.get_str("csv_filename"), - arg_parser.get_bool("flush_cache"), - arg_parser.get_int("rotating_count")}; - - auto& profiler = GemmProfiler::instance(setting); - - try - { - auto kernel_func = get_kernel_func_by_trait(arg_parser); - profiler.benchmark(gemm_problem, kernel_func); - profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); - } - catch(const std::exception& e) - { - std::cerr << "Benchmark failed: " << e.what() << std::endl; - } -} - -int main(int argc, char* argv[]) -{ - try - { - auto [result, parser] = create_args(argc, argv); - if(!result) - return EXIT_FAILURE; - benchmark_gemm(parser); - return 0; - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << "\n"; - return EXIT_FAILURE; - } -} diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index ce8a6e8234..0e2619785e 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -7,8 +7,14 @@ #include #include #include +#include -#include "gemm_host_api.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_common.hpp" + +// Data types and Layouts are defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts enum class Metric { @@ -55,8 +61,9 @@ struct GemmProblem << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" - << " \"layout_c\":\"" << problem.layout_c_ << "\"\n" - << " \"structured_sparsity\":\"" << problem.structured_sparsity_ << "\"\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\",\n" + << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") + << "\n" << "}"; return os; } @@ -105,9 +112,8 @@ struct KernelInstance friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" - << " \"name\": \"" << "{\n" - << obj.name_ << "\n}" << "\",\n" - << " \"problem\": \"" << obj.problem_ << "\",\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; return os; @@ -125,6 +131,7 @@ struct Setting std::string csv_filename_; bool flush_cache_; int rotating_count_; + bool json_output_; }; inline std::string get_rocm_version() diff --git a/tile_engine/ops/gemm/benchmark_gemm_single.cpp b/tile_engine/ops/gemm/benchmark_gemm_single.cpp new file mode 100644 index 0000000000..58532ffbe8 --- /dev/null +++ b/tile_engine/ops/gemm/benchmark_gemm_single.cpp @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_profiler.hpp" +#include "gemm_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_c", "0", "The stride value for tensor C. Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "0", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU. Default is 0, no validation.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert( + "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") + .insert( + "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "true", + "To flush cache, possible values are true or false. " + "Default is false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false. Default is " + "false") + .insert("json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false. " + "Default is " + "false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + // Get the profiler instance + auto& profiler = GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + + benchmark_gemm_single(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 392125aa0b..6a87193043 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -170,6 +170,14 @@ warp_tile_supported_combinations = { [16, 16, 128], [32, 32, 64], ], + "fp8_bf8_fp16": [ + [16, 16, 128], + [32, 32, 64], + ], + "bf8_fp8_fp16": [ + [16, 16, 128], + [32, 32, 64], + ], }, } diff --git a/tile_engine/ops/gemm/configs/benchmark.json b/tile_engine/ops/gemm/configs/benchmark.json index def3ca4453..b15b587147 100644 --- a/tile_engine/ops/gemm/configs/benchmark.json +++ b/tile_engine/ops/gemm/configs/benchmark.json @@ -5,20 +5,17 @@ "tile_m": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, "tile_n": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, "tile_k": { "max": 256, "min": 64, - "step": 64, - "exclude": [192] + "step": 64 }, "warp_m": { "values": [ @@ -79,7 +76,8 @@ }, "epilogue": { "values": [ - "cshuffle" + "cshuffle", + "default" ] }, "pad_m": { diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json index 5bd51b809a..b245c3167f 100644 --- a/tile_engine/ops/gemm/configs/default_config.json +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -1,105 +1,105 @@ { - "problem": { - }, - "tile_config": { - "tile_m": { - "values": [ - 256 - ] + "problem": { }, - "tile_n": { - "values": [ - 128, - 256 - ] + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_k": { + "max": 256, + "min": 64, + "step": 64 + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } }, - "tile_k": { - "values": [ - 32 - ] - }, - "warp_m": { - "values": [ - 1, - 2, - 4 - ] - }, - "warp_n": { - "values": [ - 1, - 2, - 4 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 4, - 16, - 32 - ] - }, - "warp_tile_n": { - "values": [ - 16, - 32, - 64 - ] - }, - "warp_tile_k": { - "values": [ - 8, - 16, - 32, - 64, - 128 - ] + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "compv4", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, + true + ] + } } - }, - "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4", - "mem" - ] - }, - "scheduler": { - "values": [ - "intrawave", - "interwave" - ] - }, - "epilogue": { - "values": [ - "cshuffle", - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false - ] - } - } } diff --git a/tile_engine/ops/gemm/gemm_benchmark.py b/tile_engine/ops/gemm/gemm_benchmark.py new file mode 100755 index 0000000000..3b0f0e619d --- /dev/null +++ b/tile_engine/ops/gemm/gemm_benchmark.py @@ -0,0 +1,721 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import sys +import json +import subprocess +import argparse +import csv +import time +from pathlib import Path +from typing import List, Dict, Tuple, Optional + + +class GemmBenchmark: + def __init__(self, build_dir: str, verbose: bool = False): + self.build_dir = Path(build_dir) + self.verbose = verbose + self.results = [] + + def discover_kernels(self) -> List[Path]: + """Find all benchmark_gemm_* executables in the build directory""" + bin_dir = self.build_dir / "bin" + if not bin_dir.exists(): + print(f"Error: Binary directory {bin_dir} does not exist") + return [] + + kernels = list(bin_dir.glob("benchmark_gemm_*")) + if self.verbose: + print(f"Found {len(kernels)} kernel executables") + for k in kernels: + print(f" - {k.name}") + return kernels + + def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]: + """Extract comprehensive kernel information from filename""" + name = kernel_path.stem + + # Initialize with basic info + info = { + "executable": str(kernel_path), + "name": name, + "data_type": "unknown", + "layout": "unknown", + "pipeline": "unknown", + "scheduler": "unknown", + "epilogue": "unknown", + } + + # Parse the kernel name pattern: + # benchmark_gemm_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16 + parts = name.split("_") + + if len(parts) >= 3: + # Extract data type (3rd part after benchmark_gemm_) + info["data_type"] = parts[2] if len(parts) > 2 else "unknown" + + # Extract layout (4th part) + info["layout"] = parts[3] if len(parts) > 3 else "unknown" + + # Extract pipeline (5th part) + info["pipeline"] = parts[4] if len(parts) > 4 else "unknown" + + # Extract epilogue (6th part) + info["epilogue"] = parts[5] if len(parts) > 5 else "unknown" + + # Extract scheduler (7th part) + info["scheduler"] = parts[6] if len(parts) > 6 else "unknown" + + # Extract detailed configuration from the end of the name + config_info = self.parse_detailed_config(name) + info.update(config_info) + + # Generate config ID + info["config_id"] = self.generate_config_id(info) + + return info + + def parse_detailed_config(self, kernel_name: str) -> Dict: + """Parse detailed configuration from kernel name""" + config = { + "tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0}, + "warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0}, + "warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0}, + "optimization_flags": { + "pad_m": False, + "pad_n": False, + "pad_k": False, + "persistent": False, + }, + } + + # Split by underscore and look for patterns + parts = kernel_name.split("_") + + # Look for boolean flags (sequence of True/False values) + bool_sequence = [] + for i, part in enumerate(parts): + if part in ["True", "False"]: + bool_sequence.append(part == "True") + # Continue collecting consecutive boolean values + j = i + 1 + while j < len(parts) and parts[j] in ["True", "False"]: + bool_sequence.append(parts[j] == "True") + j += 1 + break + + # Assign boolean flags if we found them + # Order: pad_m, pad_n, pad_k, persistent (4 flags total) + if len(bool_sequence) >= 4: + config["optimization_flags"]["pad_m"] = bool_sequence[0] + config["optimization_flags"]["pad_n"] = bool_sequence[1] + config["optimization_flags"]["pad_k"] = bool_sequence[2] + config["optimization_flags"]["persistent"] = bool_sequence[3] + + # Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16) + # The pattern is: tile_sizes_warp_config_warp_tile + dimension_groups = [] + for part in parts: + if "x" in part and len(part.split("x")) == 3: + try: + dims = [int(x) for x in part.split("x")] + if all(d > 0 for d in dims): + dimension_groups.append(dims) + except ValueError: + continue + + # Assign dimensions based on order and magnitude + if len(dimension_groups) >= 3: + # Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Largest dimensions = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smallest dimensions = warp config + config["warp_config"]["warp_m"] = sorted_groups[2][0] + config["warp_config"]["warp_n"] = sorted_groups[2][1] + config["warp_config"]["warp_k"] = sorted_groups[2][2] + + # Middle dimensions = warp tile + config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0] + config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1] + config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 2: + # If only 2 groups, assign based on magnitude + sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True) + + # Larger = tile sizes + config["tile_sizes"]["tile_m"] = sorted_groups[0][0] + config["tile_sizes"]["tile_n"] = sorted_groups[0][1] + config["tile_sizes"]["tile_k"] = sorted_groups[0][2] + + # Smaller = warp config + config["warp_config"]["warp_m"] = sorted_groups[1][0] + config["warp_config"]["warp_n"] = sorted_groups[1][1] + config["warp_config"]["warp_k"] = sorted_groups[1][2] + elif len(dimension_groups) == 1: + # Only one group - assume it's tile sizes + config["tile_sizes"]["tile_m"] = dimension_groups[0][0] + config["tile_sizes"]["tile_n"] = dimension_groups[0][1] + config["tile_sizes"]["tile_k"] = dimension_groups[0][2] + + return config + + def generate_config_id(self, info: Dict) -> str: + """Generate a compact config ID from kernel info""" + # Create a compact identifier + parts = [ + info.get("data_type", "unk"), + info.get("layout", "unk"), + info.get("pipeline", "unk"), + info.get("scheduler", "unk"), + ] + + # Add tile configuration if available + tile_sizes = info.get("tile_sizes", {}) + if tile_sizes.get("tile_m", 0) > 0: + tile_str = ( + f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}" + ) + parts.append(tile_str) + + # Add warp config if available + warp_config = info.get("warp_config", {}) + if warp_config.get("warp_m", 0) > 0: + warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}" + parts.append(warp_str) + + # Add warp tile if available + warp_tile = info.get("warp_tile", {}) + if warp_tile.get("warp_tile_m", 0) > 0: + warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}" + parts.append(warp_tile_str) + + return "_".join(parts) + + def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]: + """Run a single kernel with given parameters and save output to individual JSON file""" + # Create results directory + results_dir = self.build_dir / "results" + results_dir.mkdir(exist_ok=True) + + # Generate unique JSON filename for this kernel + json_file = results_dir / f"{kernel_path.stem}.json" + + cmd = [str(kernel_path)] + + # Add parameters + for key, value in params.items(): + cmd.append(f"-{key}={value}") + + # Add JSON output flag for clean JSON output + cmd.append("-json_output=true") + + if self.verbose: + print(f"Running: {' '.join(cmd)}") + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=60) + + if result.returncode != 0: + print(f"Error running {kernel_path.name}: {result.stderr}") + return None + + # Save raw output to individual JSON file + output = result.stdout.strip() + if output: + with open(json_file, "w") as f: + f.write(output) + + # Parse the JSON file + return self.parse_json_file(json_file) + else: + print(f"No output from {kernel_path.name}") + return None + + except subprocess.TimeoutExpired: + print(f"Timeout running {kernel_path.name}") + return None + except Exception as e: + print(f"Error running {kernel_path.name}: {e}") + return None + + def parse_json_file(self, json_file: Path) -> Optional[Dict]: + """Parse JSON data from individual kernel output file""" + try: + with open(json_file, "r") as f: + content = f.read().strip() + + # Parse the JSON directly since executables produce clean JSON + data = json.loads(content) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON from {json_file}: {e}") + return None + except Exception as e: + if self.verbose: + print(f"Error reading JSON file {json_file}: {e}") + return None + + def parse_benchmark_output(self, output: str) -> Optional[Dict]: + """Parse the benchmark output format - extract JSON directly""" + try: + # Find JSON block between asterisk markers + lines = output.split("\n") + json_start = -1 + json_end = -1 + + for i, line in enumerate(lines): + if line.strip().startswith("{"): + json_start = i + elif line.strip().endswith("}") and json_start != -1: + json_end = i + break + + if json_start != -1 and json_end != -1: + json_text = "\n".join(lines[json_start : json_end + 1]) + data = json.loads(json_text) + + # Return the complete JSON data as-is, just add some convenience fields + result = data.copy() + if "perf_result" in data: + perf = data["perf_result"] + # Add convenience fields for backward compatibility + result["time_ms"] = perf.get("latency(ms)", 0) + result["tflops"] = perf.get("tflops(TFlops)", 0) + result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0) + + return result + + return None + + except json.JSONDecodeError as e: + if self.verbose: + print(f"Failed to parse JSON: {e}") + print(f"Output was: {output[:200]}...") + return None + except Exception as e: + if self.verbose: + print(f"Error parsing output: {e}") + return None + + def benchmark_problem_size( + self, + kernels: List[Path], + m: int, + n: int, + k: int, + split_k: int = 1, + verify: int = 0, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> List[Dict]: + """Benchmark all kernels for a specific problem size""" + results = [] + + params = { + "m": m, + "n": n, + "k": k, + "split_k": split_k, + "verify": verify, + "warmup": warmup, + "repeat": repeat, + "flush_cache": str(flush_cache).lower(), + "rotating_count": rotating_count, + } + + print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}") + + for kernel_path in kernels: + kernel_info = self.extract_kernel_info(kernel_path) + result = self.run_kernel(kernel_path, params) + + if result: + # Create new structured result format + structured_result = { + "name": kernel_info["name"], # Add name field for compatibility + "config_id": kernel_info["config_id"], + "problem": result.get("problem", {}), + "perf_result": result.get("perf_result", {}), + "config": { + "data_type": kernel_info["data_type"], + "layout": kernel_info["layout"], + "pipeline": kernel_info["pipeline"], + "scheduler": kernel_info["scheduler"], + "epilogue": kernel_info["epilogue"], + "tile_sizes": kernel_info.get("tile_sizes", {}), + "warp_config": kernel_info.get("warp_config", {}), + "warp_tile": kernel_info.get("warp_tile", {}), + "optimization_flags": kernel_info.get("optimization_flags", {}), + }, + "executable": kernel_info["executable"], + # Keep backward compatibility fields + "time_ms": result.get("time_ms", 0), + "tflops": result.get("tflops", 0), + "bandwidth_gb_s": result.get("bandwidth_gb_s", 0), + } + + results.append(structured_result) + + if self.verbose: + print( + f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms" + ) + + return results + + def find_best_kernel( + self, results: List[Dict], metric: str = "tflops" + ) -> Optional[Dict]: + """Find the best performing kernel based on metric""" + if not results: + return None + + if metric == "tflops": + return max(results, key=lambda x: x.get("tflops", 0)) + elif metric == "time_ms": + return min(results, key=lambda x: x.get("time_ms", float("inf"))) + elif metric == "bandwidth_gb_s": + return max(results, key=lambda x: x.get("bandwidth_gb_s", 0)) + else: + raise ValueError(f"Unknown metric: {metric}") + + def benchmark_sweep( + self, + problem_sizes: List[Tuple[int, int, int]], + split_k_values: List[int] = [1], + verify: bool = False, + warmup: int = 50, + repeat: int = 100, + flush_cache: bool = True, + rotating_count: int = 1000, + ) -> Dict: + """Run comprehensive benchmark sweep""" + kernels = self.discover_kernels() + if not kernels: + print("No kernels found!") + return {} + + all_results = [] + best_kernels = {} + + for m, n, k in problem_sizes: + for split_k in split_k_values: + results = self.benchmark_problem_size( + kernels, + m, + n, + k, + split_k, + verify=2 if verify else 0, + warmup=warmup, + repeat=repeat, + flush_cache=flush_cache, + rotating_count=rotating_count, + ) + + all_results.extend(results) + + # Find best kernel for this configuration + best = self.find_best_kernel(results) + if best: + key = f"m{m}_n{n}_k{k}_splitk{split_k}" + best_kernels[key] = best + print( + f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)" + ) + + self.results = all_results + return best_kernels + + def export_csv(self, filename: str): + """Export all results to CSV""" + if not self.results: + print("No results to export") + return + + # Get all unique keys from results + all_keys = set() + for result in self.results: + all_keys.update(result.keys()) + + # Sort keys for consistent output + fieldnames = sorted(all_keys) + + with open(filename, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(self.results) + + print(f"Results exported to {filename}") + + def export_best_kernels(self, best_kernels: Dict, filename: str): + """Export best kernel selections to file""" + with open(filename, "w") as f: + f.write("# Best kernel selections\n") + f.write( + "# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n" + ) + + for key, kernel in sorted(best_kernels.items()): + f.write( + f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n" + ) + + print(f"Best kernels exported to {filename}") + + def export_json(self, filename: str, best_kernels: Dict = None): + """Export all results and best kernels to JSON with comprehensive metadata""" + from datetime import datetime + + # Calculate comprehensive summary statistics for all metrics + successful_results = [r for r in self.results if r.get("tflops", 0) > 0] + + tflops_values = [r.get("tflops", 0) for r in successful_results] + bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results] + latency_values = [ + r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0 + ] + + # Performance breakdown by kernel type + pipeline_stats = {} + scheduler_stats = {} + data_type_stats = {} + + for result in successful_results: + # Get config info from the new structure + config = result.get("config", {}) + + # Pipeline statistics + pipeline = config.get("pipeline", "unknown") + if pipeline not in pipeline_stats: + pipeline_stats[pipeline] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + pipeline_stats[pipeline]["count"] += 1 + pipeline_stats[pipeline]["best_tflops"] = max( + pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0) + ) + + # Scheduler statistics + scheduler = config.get("scheduler", "unknown") + if scheduler not in scheduler_stats: + scheduler_stats[scheduler] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + scheduler_stats[scheduler]["count"] += 1 + scheduler_stats[scheduler]["best_tflops"] = max( + scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0) + ) + + # Data type statistics + data_type = config.get("data_type", "unknown") + if data_type not in data_type_stats: + data_type_stats[data_type] = { + "count": 0, + "avg_tflops": 0, + "best_tflops": 0, + } + data_type_stats[data_type]["count"] += 1 + data_type_stats[data_type]["best_tflops"] = max( + data_type_stats[data_type]["best_tflops"], result.get("tflops", 0) + ) + + # Calculate averages for breakdown stats + for stats_dict, field_name in [ + (pipeline_stats, "pipeline"), + (scheduler_stats, "scheduler"), + (data_type_stats, "data_type"), + ]: + for key in stats_dict: + relevant_results = [ + r + for r in successful_results + if r.get("config", {}).get(field_name, "unknown") == key + ] + if relevant_results: + stats_dict[key]["avg_tflops"] = sum( + r.get("tflops", 0) for r in relevant_results + ) / len(relevant_results) + + output_data = { + "benchmark_metadata": { + "timestamp": datetime.now().isoformat(), + "total_kernels_tested": len(self.results), + "unique_kernels": len( + set(r.get("name", "unknown") for r in self.results) + ), + "successful_runs": len(successful_results), + "failed_runs": len(self.results) - len(successful_results), + }, + "performance_summary": { + "tflops_stats": { + "best": max(tflops_values, default=0), + "average": sum(tflops_values) / len(tflops_values) + if tflops_values + else 0, + "min": min(tflops_values, default=0), + "median": sorted(tflops_values)[len(tflops_values) // 2] + if tflops_values + else 0, + }, + "bandwidth_stats": { + "best_gb_s": max(bandwidth_values, default=0), + "average_gb_s": sum(bandwidth_values) / len(bandwidth_values) + if bandwidth_values + else 0, + "min_gb_s": min(bandwidth_values, default=0), + "median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2] + if bandwidth_values + else 0, + }, + "latency_stats": { + "best_ms": min(latency_values, default=0), + "average_ms": sum(latency_values) / len(latency_values) + if latency_values + else 0, + "max_ms": max(latency_values, default=0), + "median_ms": sorted(latency_values)[len(latency_values) // 2] + if latency_values + else 0, + }, + "kernel_type_breakdown": { + "by_pipeline": pipeline_stats, + "by_scheduler": scheduler_stats, + "by_data_type": data_type_stats, + }, + "total_problem_configurations": len(best_kernels) + if best_kernels + else 0, + }, + "kernel_results": self.results, + "best_kernels_by_problem": best_kernels or {}, + } + + with open(filename, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"JSON results exported to {filename}") + print(f" - Total kernels: {len(self.results)}") + print(f" - Successful runs: {len(successful_results)}") + print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}") + print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s") + print(f" - Best latency: {min(latency_values, default=0):.2f}ms") + + +def main(): + parser = argparse.ArgumentParser(description="GEMM Kernel Benchmarking Tool") + parser.add_argument( + "build_dir", help="Build directory containing kernel executables" + ) + parser.add_argument( + "--problem-sizes", + nargs="+", + default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"], + help="Problem sizes as M,N,K tuples", + ) + parser.add_argument( + "--split-k", nargs="+", type=int, default=[1], help="Split-K values to test" + ) + parser.add_argument("--verify", action="store_true", help="Enable verification") + parser.add_argument( + "--csv", default="gemm_benchmark_results.csv", help="CSV output filename" + ) + parser.add_argument( + "--best", default="best_kernels.txt", help="Best kernels output filename" + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--warmup", + type=int, + default=50, + help="Number of warmup iterations (default: 50)", + ) + parser.add_argument( + "--repeat", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--flush-cache", + action="store_true", + default=True, + help="Enable cache flushing (default: True)", + ) + parser.add_argument( + "--rotating-count", + type=int, + default=1000, + help="Number of iterations to rotate cache (default: 1000)", + ) + parser.add_argument("--json", help="JSON output filename (optional)") + + args = parser.parse_args() + + # Parse problem sizes + problem_sizes = [] + for size_str in args.problem_sizes: + try: + m, n, k = map(int, size_str.split(",")) + problem_sizes.append((m, n, k)) + except ValueError: + print(f"Invalid problem size: {size_str}") + return 1 + + # Create benchmark instance + benchmark = GemmBenchmark(args.build_dir, verbose=args.verbose) + + # Run benchmark sweep + print("Starting GEMM kernel benchmark sweep...") + start_time = time.time() + + best_kernels = benchmark.benchmark_sweep( + problem_sizes=problem_sizes, + split_k_values=args.split_k, + verify=args.verify, + warmup=args.warmup, + repeat=args.repeat, + flush_cache=args.flush_cache, + rotating_count=args.rotating_count, + ) + + elapsed_time = time.time() - start_time + print(f"\nBenchmark completed in {elapsed_time:.2f} seconds") + + # Export results + benchmark.export_csv(args.csv) + benchmark.export_best_kernels(best_kernels, args.best) + + # Export JSON if requested + if args.json: + benchmark.export_json(args.json, best_kernels) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp new file mode 100644 index 0000000000..5188915f1a --- /dev/null +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Permutation function for pk_int4_t +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + tensor(j + 6, i) = i4x2; + } + } + } +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("compv3") != std::string::npos) + { + traits.pipeline = "compv3"; + } + else if(kernel_name.find("compv4") != std::string::npos) + { + traits.pipeline = "compv4"; + } + else if(kernel_name.find("mem") != std::string::npos) + { + traits.pipeline = "mem"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else + { + traits.scheduler = "intrawave"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp deleted file mode 100644 index f28f5dd29c..0000000000 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ /dev/null @@ -1,223 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "ck_tile/host.hpp" -#include "gemm_dispatcher.hpp" -#include "gemm_common.hpp" - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -inline auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") - .insert("n", "4096", "The value for n dimension. Default is 4096.") - .insert("k", "2048", "The value for k dimension. Default is 2048.") - .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") - .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") - .insert("stride_c", "0", "The stride value for tensor C Default is 0.") - .insert("split_k", "1", "The split value for k dimension. Default is 1.") - .insert("verify", - "2", - "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " - "for validation on GPU. Default is 2, validation on GPU.") - .insert("log", - "false", - "Wether output kernel instance information or not. Possible values are true or " - "false. Default is false") - .insert( - "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") - .insert( - "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") - .insert("timer", - "true", - "Whether if the timer is gpu timer or not. Possible values are false or true. " - "Default is true.") - .insert("init", - "0", - "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " - "for constant(1). Default is 0, random.") - .insert("flush_cache", - "false", - "To flush cache, possible values are true or false. " - "Default is false.") - .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.") - .insert("metric", - "0", - "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " - "tflops, or 2 for bandwidth. Default is 0, latency.") - .insert("csv_filename", - "gemm_kernel", - "The filename of benchmark result. Default is gemm_kernel.") - .insert("structured_sparsity", - "false", - "Whether use sparsity kernel or not. Possible values are true or false. Default is " - "false") - .insert( - "pipeline", - "compv3", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") - .insert("scheduler", - "intrawave", - "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " - "compv3.") - .insert( - "epilogue", - "cshuffle", - "The type of epilogue. Possible values are cshuffle or default. Default is csshuffle.") - .insert("pad_m", - "false", - "Whether pad or not in m direction. Possible values are true or false. Default is " - "false.") - .insert("pad_n", - "false", - "Whether pad or not in n direction. Possible values are true or false. Default is " - "false.") - .insert("pad_k", - "false", - "Whether pad or not in k direction. Possible values are true or false. Default is " - "false.") - .insert("persistent", "false", "Whether to use persistent kernel. Default is false."); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - -auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) -{ - KernelTraits trait; - trait.pipeline = arg_parser.get_str("pipeline"); - trait.scheduler = arg_parser.get_str("scheduler"); - trait.epilogue = arg_parser.get_str("epilogue"); - trait.pad_m = arg_parser.get_bool("pad_m"); - trait.pad_n = arg_parser.get_bool("pad_n"); - trait.pad_k = arg_parser.get_bool("pad_k"); - trait.persistent = arg_parser.get_bool("persistent"); - - bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); - - return GemmDispatcher::dispatch(structured_sparsity, trait); -} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py old mode 100755 new mode 100644 index 7def4e2691..d679be7b84 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,361 +1,597 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -# -*- coding: utf-8 -*- - -""" -generate kernel instances to speed up compilation -""" +#!/usr/bin/env python +import os +import json import argparse import itertools +import multiprocessing +import concurrent.futures from pathlib import Path -from typing import List, Optional -from json_config import GemmConfig, RangeConfigParam -from codegen_utils import ( - DATA_TYPE_MAP, - LAYOUT_MAP, - PIPELINE_MAP, - SCHEDULER_MAP, - EPILOGUE_MAP, - BOOL_MAP, - warp_tile_supported_combinations, - trait_unsupported_combinations, - element_size, - get_gpu_name_by_id, -) import logging +from validation_utils import is_tile_config_valid, is_trait_combination_valid logging.basicConfig(level=logging.INFO) -class GemmCodeGenerator: - """GEMM (General Matrix Multiplication) code generator.""" +class GemmKernelBuilder: + def __init__(self, working_path, datatype, layout, config_json=None): + self.working_path = Path(working_path) + self.datatype = datatype + self.layout = layout + self.config_json = config_json - def __init__( - self, output_dir: str, user_provided_config: Optional[GemmConfig] = None - ): - self.output_dir = Path(output_dir) - self.output_dir.mkdir(parents=True, exist_ok=True) + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) - if user_provided_config is not None: - self.config = user_provided_config + # Load configuration + if config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) else: - config_path = ( - Path(__file__).resolve().parent / "configs" / "default_config.json" + self.config = self._get_default_config() + + def _get_default_config(self): + """Return default configuration if no config file is provided""" + # Define base tile configurations that work for all layouts + base_fp16_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + }, + ] + + base_fp8_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 4, + "warp_n": 1, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + }, + ] + + # Create configurations for all supported layouts + all_layouts = ["rcr", "rrr", "ccr", "crr"] + tile_configs = {} + + for datatype, base_configs in [ + ("fp16", base_fp16_configs), + ("fp8", base_fp8_configs), + ]: + tile_configs[datatype] = {} + for layout in all_layouts: + tile_configs[datatype][layout] = base_configs + + return { + "tile_configs": tile_configs, + "traits": { + "pipelines": ["mem", "compv3", "compv4"], + "epilogues": ["default", "cshuffle"], + "schedulers": ["intrawave", "interwave"], + }, + "structured_sparsity": ["false"], + "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]}, + "persistent": ["false"], + } + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + if "tile_configs" in self.config: + # Old format + return ( + self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) ) - self.config = GemmConfig.from_json(config_path) + elif "tile_config" in self.config: + # New format - generate combinations from individual parameter values + tile_config = self.config["tile_config"] - self.valid_trait_names: List[str] = [] - self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {} + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) + tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) + tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) + warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) + warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) + warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) + warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) + warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) + warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) - def list_all_trait_names(self): - """List all possible kernel trait names into file.""" - w_p = Path(self.output_dir) - file_path = w_p / "gemm_instance_blobs.txt" - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - file_range_map = {} - # Write all file paths to the header file - files_listed = 0 - with file_path.open("w") as f: - # Core files - core_files = [ - "gemm_common.hpp", - "gemm_instances.hpp", - "gemm_dispatcher.hpp", - ] - for core_file in core_files: - f.write(str(w_p / core_file) + "\n") - files_listed += 1 - - # Trait header files - for trait in self.valid_trait_names: - trait_file = f"gemm_{trait}.hpp" - f.write(str(w_p / trait_file) + "\n") - files_listed += 1 - file_name = set() - # Instance source files - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - start_idx = files_listed - for tile in tile_valid_params: - for ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - _, - _, - _, - ) in tile: - instance_name = f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - - if instance_name not in file_name: - file_name.add(instance_name) - f.write(str(w_p / instance_name) + "\n") - files_listed += 1 - - file_range_map[trait] = (start_idx, files_listed) - - file_path = w_p / "gemm_instance_blobs_range.txt" - with file_path.open("w") as f: - for name, ranges in file_range_map.items(): - s, l = ranges - f.write(name + " " + f"{s}" + " " + f"{l}" + "\n") - - def _generate_all_traits(self): - """Generate all possible kernel traits names.""" - params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k", "persistent"] - - # Generate all unique_combinations - _unique = set( - itertools.product( - *[getattr(self.config.trait_config, param).values for param in params] - ) - ) - - for combo in _unique: - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = combo - current_combination = (pipeline, epilogue, scheduler) - - if current_combination not in trait_unsupported_combinations: - trait_name = ( - f"{pipeline}_{epilogue}_{scheduler}_" - f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}_" - f"{BOOL_MAP(persistent)}" - ) - self.valid_trait_names.append(trait_name) - else: - logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}") - - def generate_all_instance_files(self): - """Generate all kernel instances files.""" - self._generate_common_header_file() - self._generate_all_trait_files() - self._generate_dispatcher_file() - - def _generate_common_header_file(self): - """Generate common header file with datatypes and layout.""" - - # Determine appropriate accumulation type based on input types - a_type = self.config.problem.datatype_map["matrix_a"] - b_type = self.config.problem.datatype_map["matrix_b"] - c_type = self.config.problem.datatype_map["matrix_c"] - - if a_type in ["int8", "int4"] and b_type in ["int8", "int4"]: - acc_type = "ck_tile::int32_t" + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs else: - acc_type = "float" + # Fallback to default + return [] - content = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" - -// Data types -using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_a"]]}; -using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_b"]]}; -using AccDataType = {acc_type}; -using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_c"]]}; - -// Layout configurations -using ALayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_a"]]}; -using BLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_b"]]}; -using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]}; -""" - - (self.output_dir / "gemm_common.hpp").write_text(content) - - def _generate_all_trait_files(self): - """Generate all kernel traits into files.""" - if not self.valid_trait_names: - self._generate_all_traits() - self._get_valid_trait_tile_combinations() - for trait in self.valid_trait_names: - self._generate_trait_file(trait) - self._generate_instantiation_source_files() - self._generate_common_instance_header_file() - - def _generate_trait_file(self, trait: str): - """Generate a trait with all tile/warp combinations.""" - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = trait.split("_") - filename = f"gemm_{trait}.hpp" - - content = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "gemm_common.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/host.hpp" - -namespace {trait} {{ -""" - # Add template struct with configuration - content += self._generate_kernel_struct( - pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent) - - content += f"\n}} // namespace {trait}\n" - (self.output_dir / filename).write_text(content) - - def _generate_kernel_struct( + def _validate_tile_config( self, - pipeline: str, - epilogue: str, - scheduler: str, - pad_m: str, - pad_n: str, - pad_k: str, - persistent: str, - ) -> str: - """Generate the code block of kernel struct""" - return f""" + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="mem", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False -template -struct GemmKernel {{ - static constexpr bool kPadM = {pad_m}; - static constexpr bool kPadN = {pad_n}; - static constexpr bool kPadK = {pad_k}; - static constexpr bool kPersistent = {persistent}; + # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False - static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ - static constexpr bool permuteA = false; - static constexpr bool permuteB = false; - static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; - static constexpr bool TransposeC = false; + return True + else: + # Full validation for generation + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype - static constexpr int kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence, - permuteA, - permuteB>; + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + ) - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + if "traits" in self.config: + # Old format + traits = self.config["traits"] + pipelines = traits["pipelines"] + epilogues = traits["epilogues"] + schedulers = traits["schedulers"] - using Traits = - ck_tile::TileGemmTraits; + padding = self.config["padding"] + persistent = self.config["persistent"] - using GemmUniversalTraits = - ck_tile::TileGemmUniversalTraits; + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + padding["pad_m"], + padding["pad_n"], + padding["pad_k"], + persistent, + ) + ) - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) - using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + elif "trait_config" in self.config: + # New format + trait_config = self.config["trait_config"] - const ck_tile::index_t k_grain = args.k_batch * TileK; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + pipelines = trait_config.get("pipeline", {}).get("values", ["mem"]) + epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) + schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"]) + pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) + pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) + pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) + persistent_values = trait_config.get("persistent", {}).get( + "values", [False] + ) + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler = combo[:3] + if is_trait_combination_valid(pipeline, epilogue, scheduler): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}" + ) + else: + # Fallback to minimal default + combinations = [("mem", "default", "intrawave", False, False, False, False)] + + return combinations + + def _get_dtype_string(self): + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(self.datatype, "float") + + _LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", + } + + def _get_abc_layouts(self, layout_code: str | None = None): + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. + If layout_code is None, use self.layout. + """ + if layout_code is None: + # fall back to the instance field + layout_code = getattr(self, "layout", "") + + code = str(layout_code).strip().lower() + + if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code): + raise ValueError( + f"Invalid layout '{layout_code}'. " + "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)." + ) + + a_layout = self._LAYOUT_MAP[code[0]] + b_layout = self._LAYOUT_MAP[code[1]] + c_layout = self._LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout + + def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + + # Map scheduler names to the correct enum values + scheduler_type_map = { + "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave", + "interwave": "ck_tile::GemmPipelineScheduler::Interwave", + "default": "ck_tile::GemmPipelineScheduler::Default", + } + + # Determine accumulator type based on datatype + acc_type = "float" + if self.datatype in ["int8", "int4"]: + acc_type = "ck_tile::int32_t" + + # Determine output type + c_type = self._get_dtype_string() + if self.datatype in ["fp8", "bf8"]: + c_type = "ck_tile::fp16_t" + + # Determine layouts based on self.layout + a_layout, b_layout, c_layout = self._get_abc_layouts() + + # Map pipeline names to base pipeline for hot loop detection + base_pipeline_map = { + "mem": "ck_tile::BaseGemmPipelineAgBgCrMem", + "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4", + } + + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} +{pragma_line} +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +using ADataType = {self._get_dtype_string()}; +using BDataType = {self._get_dtype_string()}; +using AccDataType = {acc_type}; +using CDataType = {c_type}; + +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; + +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + + // Traits + static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; + static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; + static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"}; + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + false, false>; + + // Tile partitioner + using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Traits + using Traits = ck_tile::TileGemmTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + Traits>; + + // Base pipeline for hot loop detection + using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}; + + static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - + float ave_time{{0}}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")}; + [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + ck_tile::TileGemmUniversalTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}; + + // Epilogue +""" - using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; - {EPILOGUE_MAP[epilogue]} - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + # Add epilogue configuration based on type + if epilogue == "cshuffle": + instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; +""" + else: # default epilogue + instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + kPadM, + kPadN, + WarpTileM, // kMPerXdl_ + WarpTileN, // kNPerXdl_ + WarpTileK, // kKPerXdl_ + TransposeC>; // isCTransposed_ + + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; +""" - if(!Kernel::IsSupportedArgument(kargs)) - {{ + instance_code += f""" + + // Kernel type + using GemmKernel = ck_tile::GemmKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = {'Kernel::MaxOccupancyGridSize(stream)' if persistent == 'true' else 'Kernel::GridSize(args.M, args.N, args.k_batch)'}; - - if(stream.log_level_ > 0) - {{ - std::cout << "Launching kernel with args:" - << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" - << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" - << std::endl; - }} - - if(stream.flush_cache_) - {{ - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; - - auto is_row_major = [](auto layout_) {{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{{}}; - }}; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{{}}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{{}}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() {{ - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); - }}; - ave_time = ck_tile::launch_kernel_time_mask( - stream, - run_flush_cache, - ck_tile::make_kernel( - Kernel{{}}, grids, blocks, 0, kargs)); - }} - else{{ - ave_time = ck_tile::launch_kernel(stream, - ck_tile::make_kernel( - Kernel{{}}, grids, blocks, 0, kargs)); + + // Get grid and block sizes + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n' + << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; }} + + // Launch kernel + constexpr int kBlockPerCu = 1; + ave_time = ck_tile::launch_kernel( + stream, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + return ave_time; - }}; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ @@ -373,484 +609,324 @@ struct GemmKernel {{ }}; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; }} - - static std::string get_name() {{ - return std::string("gemm_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) + - "_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" + - std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" + - "{pad_m}" + "_" + - "{pad_n}" + "_" + - "{pad_k}" + "_" + - "{pipeline}" + "_" + - "{epilogue}" + "_" + - "{scheduler}" + "_" + - "{persistent}"; - }} }}; """ - def _generate_common_instance_header_file(self): - """Generate common instance header into file.""" - content = """// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#pragma once -""" - for trait in self.valid_trait_names: - content += f'#include "gemm_{trait}.hpp"\n' - (self.output_dir / "gemm_instances.hpp").write_text(content) + return kernel_name, instance_code - def is_tile_valid(self, tile: tuple, trait: str) -> bool: - """Check if the tile configuration is valid for the given trait.""" - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile - pipeline, *_ = trait.split("_") + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues - # Parameter validity check - invalid_params = [] - if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]: - invalid_params.append( - f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})" - ) - if (warp_m * warp_tile_m) == 0: - invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})") - if (warp_n * warp_tile_n) == 0: - invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})") - if (warp_k * warp_tile_k) == 0: - invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})") + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() - if invalid_params: - logging.debug( - f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. " - f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), " - f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})" - ) - return False - # Dimension alignment check - alignment_issues = [] - if tile_m % (warp_m * warp_tile_m) != 0: - alignment_issues.append( - f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" - ) - if tile_n % (warp_n * warp_tile_n) != 0: - alignment_issues.append( - f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" - ) - if tile_k % (warp_k * warp_tile_k) != 0: - alignment_issues.append( - f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" - ) - - if alignment_issues: - logging.debug( - f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " - f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - ) - return False - - # LDS capacity verification - matrix_a_size = (tile_m * tile_k) * element_size( - self.config.problem.datatype_map["matrix_a"] - ) - matrix_b_size = (tile_n * tile_k) * element_size( - self.config.problem.datatype_map["matrix_b"] - ) - total_tile_in_lds = matrix_a_size + matrix_b_size - - max_tile_size = 2**15 if pipeline == "compv4" else 2**16 - - if total_tile_in_lds > max_tile_size: - logging.debug( - f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " - f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" - f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" - f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" - ) - return False - - # Warp combination validation - warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}" - current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] - - gpu_name = get_gpu_name_by_id(0) - - gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) - if not gpu_warp_tile_key: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) - if not allowed_combinations: - logging.debug( - f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check." - ) - return False - - if current_combination not in allowed_combinations: - logging.debug( - f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. " - f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}" - ) - return False - - return True - - def _get_valid_trait_tile_combinations(self): - def get_tile_value(tile_param): - return ( - tile_param.generate_candidates() - if isinstance(tile_param, RangeConfigParam) - else tile_param.values - ) - - tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.tile_m), - get_tile_value(self.config.tile_config.tile_n), - get_tile_value(self.config.tile_config.tile_k), - ) - ) - - warp_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_m), - get_tile_value(self.config.tile_config.warp_n), - get_tile_value(self.config.tile_config.warp_k), - ) - ) - - warp_tile_group = list( - itertools.product( - get_tile_value(self.config.tile_config.warp_tile_m), - get_tile_value(self.config.tile_config.warp_tile_n), - get_tile_value(self.config.tile_config.warp_tile_k), - ) - ) - - tile_params = { - t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group - } - - for trait in self.valid_trait_names: - tile_valid_params = [ - tile for tile in tile_params if self.is_tile_valid(tile, trait) - ] - - if trait not in self.valid_trait_tile_combinations: - self.valid_trait_tile_combinations[trait] = [] - self.valid_trait_tile_combinations[trait].append(tile_valid_params) - - def _generate_instantiation_source_files(self): - """Generate kernel instance instantiation source files""" - tile_map = {} - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - for tile in tile_valid_params: - for ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) in tile: - key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}" - value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" - if key not in tile_map: - tile_map[key] = set() - tile_map[key].add(value) - - files_listed = 0 - for trait, _ in self.valid_trait_tile_combinations.items(): - for block_tile, warp_tiles in tile_map.items(): - tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map( - int, block_tile.split("x") + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.working_path, + self.datatype, + self.layout, + ) ) - content = f""" -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - - -#include "gemm_{trait}.hpp" - -""" - for warp_tile in warp_tiles: - warp_tile_m, warp_tile_n, warp_tile_k = map( - int, warp_tile.split("x") - ) - - sparse = ( - self.config.problem.datatype_map["matrix_a"] == "fp16" - and self.config.problem.datatype_map["matrix_b"] == "fp16" - and self.config.problem.datatype_map["matrix_c"] == "fp16" - and ( - ( - warp_tile_m == 32 - and warp_tile_n == 32 - and warp_tile_k == 16 - ) - or ( - warp_tile_m == 16 - and warp_tile_n == 16 - and warp_tile_k == 32 - ) - ) - ) - if sparse: - files_listed = files_listed + 1 - content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;""" - ) - files_listed = files_listed + 1 - content = ( - content - + f""" -template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;""" - ) - content += f""" -""" - ( - self.output_dir - / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp" - ).write_text(content) - print(f"Generated {files_listed} kernel instances in total.") - - def _generate_dispatcher_file(self): - """Generate the code block of dispatch mechanism.""" - content = """ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include - -#include "gemm_common.hpp" -#include "gemm_instances.hpp" - -/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a -/// specific kernel instance based on the provided settings. -struct KernelTraits -{ - /// @brief The name of the pipeline. - std::string pipeline; - /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). - std::string scheduler; - /// @brief The name of the epilogue (e.g., "cshuffle", "default"). - std::string epilogue; - /// @brief Indicates whether padding is applied to the M dimension. - bool pad_m; - /// @brief Indicates whether padding is applied to the N dimension. - bool pad_n; - /// @brief Indicates whether padding is applied to the K dimension. - bool pad_k; - /// @brief Indicates whether the kernel is persistent. - bool persistent; -}; - -struct GemmDispatcher { - static auto& get_kernel_map() { - // Use a static local variable - static std::unordered_map< - std::string, - std::vector(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>> - kernel_map; - return kernel_map; - } - - static void init([[maybe_unused]]bool structured_sparsity) { - auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; - \n""" - - for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): - content += f""" kernel_map["{trait}"] = {{""" - for _, tile in enumerate(tile_valid_params): - for j in range(len(tile)): - ( - tile_m, - tile_n, - tile_k, - warp_m, - warp_n, - warp_k, - warp_tile_m, - warp_tile_n, - warp_tile_k, - ) = tile[j] - content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ - content += f""" - if(structured_sparsity){{ // SMFMA""" - sparse = ( - self.config.problem.datatype_map["matrix_a"] == "fp16" - and self.config.problem.datatype_map["matrix_b"] == "fp16" - and self.config.problem.datatype_map["matrix_c"] == "fp16" - and ( - ( - warp_tile_m == 32 - and warp_tile_n == 32 - and warp_tile_k == 16 - ) - or ( - warp_tile_m == 16 - and warp_tile_n == 16 - and warp_tile_k == 32 - ) - ) - ) - content += f""" - return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);""" - content += f""" - }} else {{""" - content += f""" - return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(False)}>>(args, stream);""" - content += f""" - }} """ - - if j == len(tile) - 1: - content += f""" - }} """ - else: - content += f""" - }}, """ - content += f""" - }};\n """ - - content += """ } - - template - static std::tuple run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) - { - std::string name = Kernel::get_name(); - float avg_time = Kernel::launch(args, stream); - - return std::make_tuple(name, avg_time); - } - - - static auto dispatch(bool structured_sparsity, const KernelTraits& trait) { - init(structured_sparsity); - const std::string key = assemble_key(trait); - auto& kernel_map = get_kernel_map(); - if(auto it = kernel_map.find(key); it != kernel_map.end()) - { - return it->second; - } - throw std::runtime_error("No suitable kernel found: " + key); - } - -private: - static std::string assemble_key(const KernelTraits &trait) { - return std::string(trait.pipeline) + "_" + - trait.epilogue + "_" + - trait.scheduler + "_" + - (trait.pad_m ? "true" : "false") + "_" + - (trait.pad_n ? "true" : "false") + "_" + - (trait.pad_k ? "true" : "false") + "_" + - (trait.persistent ? "true" : "false"); - } -}; - -""" - (self.output_dir / "gemm_dispatcher.hpp").write_text(content) - - -def do_list_blobs( - args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None -): - generator = GemmCodeGenerator(args.working_path, user_provide_config) - generator.list_all_trait_names() - - -def do_gen_blobs( - args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None -): - generator = GemmCodeGenerator(args.working_path, user_provide_config) - generator.generate_all_instance_files() - - -def main(args): - gemm_config = ( - GemmConfig.from_json(args.config_json, args.datatype, args.layout) - if args.config_json is not None - else args.config_json - ) - - if args.list_blobs: - do_list_blobs(args, gemm_config) - elif args.gen_blobs: - do_gen_blobs(args, gemm_config) - else: - logging.warning( - "No mode specified (use --list_blobs or --gen_blobs). Generating by default..." + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." ) - do_gen_blobs(args, gemm_config) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM targets +# Datatype: {self.datatype}, Layout: {self.layout} + +""" + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: + f.write(cmake_code) + + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() + + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } + ) + + # Write kernel count + with open(self.working_path / "gemm_kernel_count.txt", "w") as f: + f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") + + def run(self, num_workers=None): + """Run the builder to generate individual kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) -if __name__ == "__main__": +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + tile_config, trait_combo, working_path, datatype, layout = work_item + + # Create a temporary builder instance for this worker + builder = GemmKernelBuilder(working_path, datatype, layout) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_" prefix + # Remove "gemm_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] # Remove "gemm_" prefix + + # Write individual header file + header_file = working_path / f"gemm_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): parser = argparse.ArgumentParser( - prog="generate", - description="gen API for CK gemm kernel", + description="GEMM kernel instance builder with parallel support" ) + parser.add_argument("--working_path", required=True, help="Working directory path") parser.add_argument( - "-w", - "--working_path", - default="./", - required=False, - help="The path where all the blobs are going to be generated", - ) - parser.add_argument( - "-j", - "--config_json", - required=False, - help="Path to the json which contains the configurations that user provide", - ) - parser.add_argument( - "-d", "--datatype", required=True, - help="Specify what datatype to use for the kernel generation, e.g. fp16, bf16, int8, fp8, bf8", + choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + help="Data type", ) parser.add_argument( - "-ly", "--layout", required=True, - help="Specify what layout to use for the kernel generation, e.g. rcr, rrr", + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" ) parser.add_argument( - "-l", - "--list_blobs", - action="store_true", - help="List all kernel instances to file", + "--gen_individual", action="store_true", help="Generate individual kernel files" ) parser.add_argument( - "-g", - "--gen_blobs", + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", action="store_true", - help="Generate all kernel instances into different files", + help="List kernel configurations without generating files", ) args = parser.parse_args() - main(args) + # Create builder + builder = GemmKernelBuilder( + args.working_path, args.datatype, args.layout, args.config_json + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder.write_kernel_list() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3] == "True", # pad_m + trait_parts[4] == "True", # pad_n + trait_parts[5] == "True", # pad_k + trait_parts[6] == "True", # persistent + ) + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] + + header_file = builder.working_path / f"gemm_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + elif args.gen_individual: + # Generate all individual kernel files + builder.run(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 634e19de6e..bbf0c92e67 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -20,6 +20,25 @@ class GemmProfiler return instance; } + // Overload for single kernel benchmarking + void benchmark(GemmProblem& gemm_problem, + std::function + kernel_func) + { + // Create a vector with a single callable that returns both name and time + std::vector(ck_tile::GemmHostArgs&, + const ck_tile::stream_config&)>> + callables; + + callables.push_back( + [kernel_func](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { + float time = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time); + }); + + benchmark(gemm_problem, callables); + } + void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) @@ -161,7 +180,7 @@ class GemmProfiler kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; - if(setting_.log_ > 0) + if(setting_.log_ > 0 && !setting_.json_output_) { std::cout << kernel_instance << std::endl; } @@ -199,10 +218,18 @@ class GemmProfiler b.perf_result_, a.perf_result_, metric); }); - std::cout << "**********************************" << std::endl; - std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" - << "The best kernel instance is: " << kernel_instance << std::endl; - std::cout << "**********************************" << std::endl; + if(setting_.json_output_) + { + // Output clean JSON only + std::cout << kernel_instance << std::endl; + } + else + { + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "Current kernel performance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + } if(!setting_.csv_filename_.empty()) { diff --git a/tile_engine/ops/gemm/test_benchmark.sh b/tile_engine/ops/gemm/test_benchmark.sh new file mode 100755 index 0000000000..1fb7c163af --- /dev/null +++ b/tile_engine/ops/gemm/test_benchmark.sh @@ -0,0 +1,102 @@ +#!/bin/bash + +# Test script for tile engine GEMM benchmarks +# This script demonstrates how to run the new individual benchmark executables + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Find the build directory +if [ -z "$1" ]; then + # Try to find build directory automatically + BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1) + if [ -z "$BUILD_DIR" ]; then + echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}" + echo "Usage: $0 " + exit 1 + fi +else + BUILD_DIR="$1" +fi + +echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}" + +# Check if bin directory exists +if [ ! -d "$BUILD_DIR/bin" ]; then + echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}" + exit 1 +fi + +# Find all benchmark executables +echo -e "${YELLOW}Finding benchmark executables...${NC}" +BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null) + +if [ -z "$BENCHMARKS" ]; then + echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}" + echo "Please build some benchmarks first with:" + echo " cd $BUILD_DIR" + echo " make benchmark_gemm_" + exit 1 +fi + +# Count benchmarks +NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l) +echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}" + +# Test sizes +SIZES=(512 1024 2048) + +# Results file +RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv" + +echo -e "${YELLOW}Running benchmarks...${NC}" +echo "Results will be saved to: $RESULTS_FILE" + +# Run each benchmark +COUNTER=0 +for BENCH in $BENCHMARKS; do + COUNTER=$((COUNTER + 1)) + BENCH_NAME=$(basename "$BENCH") + echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}" + + for SIZE in "${SIZES[@]}"; do + echo -e " Testing size: ${SIZE}x${SIZE}x${SIZE}" + + # Run with verification + "$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \ + -csv_filename="$RESULTS_FILE" -csv_format=simple \ + 2>&1 | grep -E "(Time:|Performance:|Verification:|Error)" + + if [ ${PIPESTATUS[0]} -ne 0 ]; then + echo -e " ${RED}Benchmark failed!${NC}" + fi + done +done + +echo -e "\n${GREEN}Benchmark testing complete!${NC}" +echo "Results saved to: $RESULTS_FILE" + +# Show summary if CSV file exists +if [ -f "$RESULTS_FILE" ]; then + echo -e "\n${YELLOW}Summary of results:${NC}" + echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)" + echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")" + echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")" +fi + +# Example of running a specific benchmark with different options +echo -e "\n${YELLOW}Example commands for manual testing:${NC}" +echo "# Basic run:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024" +echo "" +echo "# With CPU verification:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1" +echo "" +echo "# JSON output for parsing:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true" +echo "" +echo "# Performance testing with TFLOPS metric:" +echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1" diff --git a/tile_engine/ops/gemm/test_validation.py b/tile_engine/ops/gemm/test_validation.py new file mode 100644 index 0000000000..1c9a0ff0ca --- /dev/null +++ b/tile_engine/ops/gemm/test_validation.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +""" +Test script to verify that the validation logic is working correctly. +""" + +from validation_utils import ( + is_tile_config_valid, + is_trait_combination_valid, + validate_warp_tile_combination, + get_gpu_name_by_id, +) + + +def test_warp_tile_validation(): + """Test warp tile combination validation""" + print("Testing warp tile combination validation...") + + # Get GPU name + gpu_name = get_gpu_name_by_id(0) + print(f"Detected GPU: {gpu_name}") + + # Test cases for fp16 + test_cases = [ + # (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid) + ([4, 64, 8], False), # Invalid - not in supported list + ([4, 64, 16], True), # Valid + ([32, 32, 8], True), # Valid + ([16, 16, 16], True), # Valid + ([32, 32, 16], True), # Valid + ([16, 16, 32], True), # Valid + ([64, 4, 16], True), # Valid + ([128, 128, 128], False), # Invalid - too large + ] + + print("\nTesting fp16 warp tile combinations:") + for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases: + valid, msg = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name + ) + status = "PASS" if valid == expected else "FAIL" + print(f" [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}") + if not valid and msg: + print(f" Reason: {msg}") + + +def test_trait_combinations(): + """Test trait combination validation""" + print("\n\nTesting trait combination validation...") + + test_cases = [ + # (pipeline, epilogue, scheduler, expected_valid) + ("mem", "default", "intrawave", True), + ("mem", "cshuffle", "intrawave", True), + ("compv3", "default", "interwave", False), # Invalid combination + ("compv3", "cshuffle", "interwave", False), # Invalid combination + ("compv4", "default", "interwave", False), # Invalid combination + ("compv4", "cshuffle", "interwave", False), # Invalid combination + ("compv3", "default", "intrawave", True), + ("compv4", "cshuffle", "intrawave", True), + ] + + print("\nTesting trait combinations:") + for pipeline, epilogue, scheduler, expected in test_cases: + valid = is_trait_combination_valid(pipeline, epilogue, scheduler) + status = "PASS" if valid == expected else "FAIL" + print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}") + + +def test_full_tile_config_validation(): + """Test full tile configuration validation""" + print("\n\nTesting full tile configuration validation...") + + # Test case that was failing in the build + tile_m, tile_n, tile_k = 256, 256, 32 + warp_m, warp_n, warp_k = 1, 4, 1 + warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8 + + print("\nTesting problematic configuration:") + print(f" Tile: {tile_m}x{tile_n}x{tile_k}") + print(f" Warp: {warp_m}x{warp_n}x{warp_k}") + print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") + + valid = is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + "fp16", + "fp16", + "fp16", + "mem", + ) + + print(f" Valid: {valid}") + print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)") + + # Test a valid configuration + warp_tile_k = 16 # Change to valid value + print("\nTesting corrected configuration:") + print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}") + + valid = is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + "fp16", + "fp16", + "fp16", + "mem", + ) + + print(f" Valid: {valid}") + print(" Expected: True") + + +def main(): + """Run all tests""" + print("=" * 60) + print("GEMM Validation Test Suite") + print("=" * 60) + + test_warp_tile_validation() + test_trait_combinations() + test_full_tile_config_validation() + + print("\n" + "=" * 60) + print("Test suite completed") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py new file mode 100644 index 0000000000..4948fd5744 --- /dev/null +++ b/tile_engine/ops/gemm/validation_utils.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +""" +Validation utilities for GEMM kernel generation. +Extracted from tile_engine_develop for consistency. +""" + +import subprocess +import re +from functools import lru_cache +import logging + +# Element size mapping for different data types +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, + "fp32": 4, + "fp64": 8, +} + +# Supported warp tile combinations for different GPU architectures and data types +WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Unsupported trait combinations +TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + logging.debug("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + logging.debug("GPU query timeout (5s)") + except Exception as e: + logging.debug(f"GPU detection error: {str(e)}") + + return "" + + +def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is valid.""" + return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS + + +def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: + """Validate warp configuration.""" + return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + +def validate_dimension_alignment( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, +) -> tuple[bool, list[str]]: + """Check if tile dimensions are properly aligned with warp dimensions.""" + alignment_issues = [] + + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + return len(alignment_issues) == 0, alignment_issues + + +def validate_lds_capacity( + tile_m: int, + tile_n: int, + tile_k: int, + a_datatype: str, + b_datatype: str, + pipeline: str, +) -> tuple[bool, str]: + """Validate LDS capacity requirements.""" + matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) + matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + + if total_tile_in_lds > max_tile_size: + error_msg = ( + f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" + f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False, error_msg + + return True, "" + + +def validate_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str = None, +) -> tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def is_tile_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + trait_name: str = None, +) -> bool: + """ + Comprehensive tile configuration validation. + Returns True if configuration is valid, False otherwise. + """ + # Basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Check that warp tiles fit within block tiles + if warp_m * warp_tile_m > tile_m: + return False + if warp_n * warp_tile_n > tile_n: + return False + if warp_k * warp_tile_k > tile_k: + return False + + # Validate warp configuration + if not validate_warp_configuration(warp_m, warp_n, warp_k): + logging.debug( + f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" + ) + return False + + # Validate dimension alignment + is_aligned, alignment_issues = validate_dimension_alignment( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) + if not is_aligned: + logging.debug( + f"Dimension alignment failed: {', '.join(alignment_issues)}. " + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # Validate LDS capacity + lds_valid, lds_error = validate_lds_capacity( + tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline + ) + if not lds_valid: + logging.debug(f"LDS validation failed: {lds_error}") + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + return True From d876e87fe45a58ab4f83b945a021ea5effb9b31d Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 1 Sep 2025 09:16:45 +0800 Subject: [PATCH 002/404] [CK_TILE] Add FAv3 fwd pipeline (#2731) * Add FAv3 fwd pipeline * Unpack v_pk_mul to hide v_mov * Avoid compiler moving l compute across phase * Sync sched_group_barrier() setting for masking cases --- example/ck_tile/01_fmha/CMakeLists.txt | 22 + .../ck_tile/01_fmha/example_fmha_fwd_v3.cpp | 492 +++++++ example/ck_tile/01_fmha/fmha_fwd_v3.cpp | 60 + example/ck_tile/01_fmha/fmha_fwd_v3.hpp | 67 + example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp | 159 +++ .../instances/fmha_fwd_v3_d128_bf16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_bf16_nmask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_mask.cpp | 14 + .../instances/fmha_fwd_v3_d128_fp16_nmask.cpp | 14 + .../01_fmha/script/benchmark_fwd_v3.sh | 31 + include/ck_tile/ops/fmha.hpp | 3 + .../ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 519 +++++++ .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 1198 +++++++++++++++++ ...ck_fmha_fwd_v3_pipeline_default_policy.hpp | 603 +++++++++ .../pipeline/block_fmha_pipeline_problem.hpp | 44 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 16 + 16 files changed, 3270 insertions(+) create mode 100644 example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3.hpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp create mode 100644 example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp create mode 100755 example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index bd03aee924..5f495c76d8 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -144,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) +# add fmha_fwd_v3 example +set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3") +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}") + +add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS + "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp" +) +target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE + fmha_fwd_v3.cpp + ${FMHA_FWD_V3_INSTANCES} +) + +set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS) +list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -fgpu-flush-denormals-to-zero + -Wno-undefined-func-template + --save-temps +) +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) + # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp new file mode 100644 index 0000000000..d2428e5152 --- /dev/null +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -0,0 +1,492 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fmha_fwd.hpp" +#include "fmha_fwd_v3.hpp" +#include "mask.hpp" + +auto parse_cmd_args(int argc, char* argv[]) -> std::pair +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("prec", "fp16", "data type. fp16/bf16") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert("s", "3328", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q & k") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + .insert("iperm", + "0", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "0", "permute output") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("v", "1", "0:no verify, 1:verify") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "30", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +enum class TensorLayout +{ + bhsd, + bshd, +}; + +std::ostream& operator<<(std::ostream& stream, TensorLayout layout) +{ + switch(layout) + { + case TensorLayout::bhsd: return stream << "bhsd"; + case TensorLayout::bshd: return stream << "bshd"; + default: return stream << "unknown"; + } +} + +struct Problem +{ + explicit Problem(const ck_tile::ArgParser& args) + { + data_type = args.get_str("prec") == "fp16" + ? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16 + : ck_tile::fmha_fwd_v3_args::data_type_enum::bf16; + batch = args.get_int("b"); + seqlen_q = args.get_int("s"); + seqlen_k = args.get_int("s_k"); + if(seqlen_k < 0) + { + seqlen_k = seqlen_q; + } + nhead_q = args.get_int("h"); + nhead_kv = args.get_int("h_k"); + if(nhead_kv < 0) + { + nhead_kv = nhead_q; + } + hdim = args.get_int("d"); + softmax_scale = args.get_float("scale_s"); + if(softmax_scale == .0f) + softmax_scale = 1.0 / ck_tile::sqrt(static_cast(hdim)); + mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + } + + std::vector get_query_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + std::vector get_key_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_value_shape() const + { + if(input_layout == TensorLayout::bhsd) + { + return {batch, nhead_kv, seqlen_k, hdim}; + } + else + { + return {batch, seqlen_k, nhead_kv, hdim}; + } + } + + std::vector get_output_shape() const + { + if(output_layout == TensorLayout::bhsd) + { + return {batch, nhead_q, seqlen_q, hdim}; + } + else + { + return {batch, seqlen_q, nhead_q, hdim}; + } + } + + ck_tile::fmha_fwd_v3_args::data_type_enum data_type; + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_kv; + ck_tile::index_t hdim; + float softmax_scale; + mask_info mask; + TensorLayout input_layout; + TensorLayout output_layout; +}; + +struct RunConfig +{ + explicit RunConfig(const ck_tile::ArgParser& args) + { + seed = args.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + kernel_warmup = args.get_int("warmup"); + kernel_repeat = args.get_int("repeat"); + verify = args.get_bool("v"); + } + + std::optional seed; + int kernel_warmup; + int kernel_repeat; + bool verify; +}; + +template +auto generate_qkv(const Problem& problem, + [[maybe_unused]] std::optional seed = std::nullopt) + -> std::tuple, + ck_tile::HostTensor, + ck_tile::HostTensor> +{ + ck_tile::HostTensor q(problem.get_query_shape()); + ck_tile::HostTensor k(problem.get_key_shape()); + ck_tile::HostTensor v(problem.get_value_shape()); + + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v); + + return std::make_tuple(q, k, v); +} + +namespace host { +template +CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor& q_bshd, + const ck_tile::HostTensor& k_bshd, + const ck_tile::HostTensor& v_bshd, + const mask_info& mask, + ck_tile::HostTensor& o_bshd, + const QElementOp& q_element_op = {}, + const KElementOp& k_element_op = {}, + const VElementOp& v_element_op = {}, + const SAccElementOp& s_acc_element_op = {}) +{ + const int batch_size = q_bshd.mDesc.get_lengths()[0]; + const int seqlen_q = q_bshd.mDesc.get_lengths()[1]; + const int seqlen_kv = k_bshd.mDesc.get_lengths()[1]; + const int nhead_q = q_bshd.mDesc.get_lengths()[2]; + const int nhead_kv = k_bshd.mDesc.get_lengths()[2]; + const int hdim_qk = q_bshd.mDesc.get_lengths()[3]; + const int hdim_v = v_bshd.mDesc.get_lengths()[3]; + + const int nr = nhead_q / nhead_kv; + + ck_tile::HostTensor q_host_ref({nhead_q, seqlen_q, hdim_qk}); + ck_tile::HostTensor k_host_ref({nhead_q, seqlen_kv, hdim_qk}); + ck_tile::HostTensor v_host_ref({nhead_q, hdim_v, seqlen_kv}); + ck_tile::HostTensor o_host_ref({nhead_q, seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead_q, seqlen_q, seqlen_kv}); + ck_tile::HostTensor p_host_ref({nhead_q, seqlen_q, seqlen_kv}); + + // do computation for each batch + for(int b = 0; b < batch_size; ++b) + { + // copy per-batch data from input tensors + // clang-format off + q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); }); + k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); }); + v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); }); + // clang-format on + ck_tile::reference_batched_gemm( + q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op); + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, seqlen_q, seqlen_kv)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + seqlen_q, + seqlen_kv, + mask.type == mask_enum::mask_top_left)); + } + + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, ck_tile::identity{}); + + ck_tile::reference_batched_gemm( + p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op); + + // copy resulting per-batch data to the output tensor + o_host_ref.ForEach( + [&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); }); + } +} +} // namespace host + +template +bool run_impl(const Problem& problem, const RunConfig& run_config) +{ + auto [q, k, v] = generate_qkv(problem, run_config.seed); + + ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes()); + /// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v + ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q.data()); + k_buf.ToDevice(k.data()); + v_buf.ToDevice(v.data()); + + ck_tile::fmha_fwd_v3_args args; + + args.data_type = problem.data_type; + args.batch = problem.batch; + args.seqlen_q = problem.seqlen_q; + args.seqlen_k = problem.seqlen_k; + args.nhead_q = problem.nhead_q; + args.nhead_kv = problem.nhead_kv; + args.hdim_qk = problem.hdim; + args.hdim_v = problem.hdim; + args.softmax_scale = problem.softmax_scale; + + args.window_size_left = problem.mask.left; + args.window_size_right = problem.mask.right; + args.mask_type = static_cast(problem.mask.type); + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.q_ptr = q_buf.GetDeviceBuffer(); + args.stride_q = + problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_q = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim; + args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.k_ptr = k_buf.GetDeviceBuffer(); + args.stride_k = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_k = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_k, nhead_kv, hdim) + // bhsd: (batch, nhead_kv, seqlen_k, hdim) + args.v_ptr = v_buf.GetDeviceBuffer(); + args.stride_v = + problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim; + args.nhead_stride_v = + problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim; + args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim; + + // bshd: (batch, seqlen_q, nhead_q, hdim) + // bhsd: (batch, nhead_q, seqlen_q, hdim) + args.o_ptr = o_buf.GetDeviceBuffer(); + args.stride_o = + problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim; + args.nhead_stride_o = problem.output_layout == TensorLayout::bshd + ? problem.hdim + : problem.seqlen_q * problem.hdim; + args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + + ck_tile::stream_config stream_config{nullptr, + true, + /*log_level=*/0, + run_config.kernel_warmup, + run_config.kernel_repeat}; + + auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config); + if(!result) + { + std::cerr << "faild to run fmha_fwd_v3()" << std::endl; + return false; + } + + std::size_t flop = [&] { + if(problem.mask.type == mask_enum::no_mask) + { + return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + else + { + /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2. + return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k * + problem.hdim; + } + }(); + float tflops = static_cast(flop) / 1.e9 / time; + + std::cout << "[" << problem.data_type << "|"; + if(problem.input_layout == problem.output_layout) + { + std::cout << problem.input_layout; + } + else + { + std::cout << problem.input_layout << "-" << problem.output_layout; + } + std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv + << ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim + << ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed + << ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops + << " TFlops" << std::endl; + + if(!run_config.verify) + { + return true; + } + + // transpose tensor descriptors from bhsd to bshd if necessary + if(problem.input_layout != TensorLayout::bshd) + { + q = q.transpose({0, 2, 1, 3}); + k = k.transpose({0, 2, 1, 3}); + v = v.transpose({0, 2, 1, 3}); + } + + ck_tile::HostTensor o_ref(problem.get_output_shape()); + if(problem.output_layout != TensorLayout::bshd) + { + o_ref = o_ref.transpose({0, 2, 1, 3}); + } + + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + ck_tile::HostTensor o(problem.get_output_shape()); + o_buf.FromDevice(o.data()); + + const auto [rtol, atol] = [&] { + if constexpr(std::is_same_v) + return std::make_tuple(1e-3, 1e-3); + else + return std::make_tuple(1e-2, 1e-2); + }(); + return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol); +} + +int main(int argc, char* argv[]) +{ + auto [parse_result, args] = parse_cmd_args(argc, argv); + if(!parse_result) + { + std::cerr << "failed to parse command line arguments" << std::endl; + } + + Problem problem(args); + RunConfig run_config(args); + + const auto run = [&] { + if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16) + { + return run_impl(problem, run_config); + } + else + { + return run_impl(problem, run_config); + } + }; + + return !run(); +} diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp new file mode 100644 index 0000000000..30019167fb --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" +#include "mask.hpp" + +namespace ck_tile { + +std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type) +{ + switch(data_type) + { + case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16"; + case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16"; + default: return stream << "unknown"; + } +} + +std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config) +{ + if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + } + else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16) + { + if(args.mask_type == static_cast(mask_enum::no_mask)) + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + else + { + using kernel_traits = + fmha_fwd_v3_kernel_traits; + + return fmha_fwd_v3_kernel_dispatch(args, config); + } + } + + return std::make_pair(false, -1.f); +} + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp new file mode 100644 index 0000000000..5361d27f0f --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/stream_config.hpp" + +namespace ck_tile { + +struct fmha_fwd_v3_args +{ + enum class data_type_enum + { + fp16, + bf16 + }; + + data_type_enum data_type; + // bool is_varlen; + + index_t batch; + index_t seqlen_q; + index_t seqlen_k; + index_t nhead_q; + index_t nhead_kv; + index_t hdim_qk; + index_t hdim_v; + + float softmax_scale; + + index_t window_size_left; + index_t window_size_right; + index_t mask_type; + + const void* q_ptr; + index_t stride_q; + index_t nhead_stride_q; + index_t batch_stride_q; + + const void* k_ptr; + index_t stride_k; + index_t nhead_stride_k; + index_t batch_stride_k; + + const void* v_ptr; + index_t stride_v; + index_t nhead_stride_v; + index_t batch_stride_v; + + void* o_ptr; + index_t stride_o; + index_t nhead_stride_o; + index_t batch_stride_o; +}; + +std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +std::pair fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp new file mode 100644 index 0000000000..d6e4ac4c60 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" + +#include "fmha_fwd_v3.hpp" + +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ + template <> \ + std::pair fmha_fwd_v3_kernel_dispatch( \ + const fmha_fwd_v3_args& args, const stream_config& config) \ + { \ + return std::make_pair(true, \ + fmha_fwd_v3_kernel_launch(args, config)); \ + } + +namespace ck_tile { + +template +struct fmha_fwd_v3_problem_traits; + +template <> +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::bf16_t; + using acc_dtype = float; + using o_dtype = ck_tile::bf16_t; + using lse_dtype = float; +}; + +template +struct fmha_fwd_v3_kernel_traits +{ + static constexpr auto date_type = DataType; + static constexpr bool is_variable_seqlen = IsVariableSeqlen; + static constexpr bool is_masking = IsMasking; + + // M0 N0 K0 N1 K1 + using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; + using fmha_warp_gemm_shape = sequence<32, 32, 16>; + using fmha_block_warps = sequence<8, 1, 1>; + + using fmha_shape = TileFmhaShape; + + using fmha_traits = TileFmhaFwdV3Traits; + + using fmha_mask = SimplifiedGenericAttentionMask; + + using fmha_pipeline_problem = + BlockFmhaFwdV3PipelineProblem::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::lse_dtype, + typename fmha_fwd_v3_problem_traits::qkvp_dtype, + typename fmha_fwd_v3_problem_traits::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, + fmha_shape, + IsVariableSeqlen, + fmha_mask, + fmha_traits>; + + using fmha_pipeline = BlockFmhaFwdV3Pipeline; + + using epilogue = Default2DEpilogue< + Default2DEpilogueProblem::acc_dtype, + typename fmha_fwd_v3_problem_traits::o_dtype, + true, // kPadM + true, // kPadM + true // UseRawStore + >>; + + using kernel = FmhaFwdV3Kernel; +}; + +template +float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) +{ + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + nullptr, // lse_ptr + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_qk, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_kv, + args.softmax_scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 0, // nhead_stride_lse + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + 0, // batch_stride_lse + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr index_t kBlockPerCu = Kernel::kBlockPerCu; + + return launch_kernel(config, make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} + +// return value: +// first = whether the kernel was launched (true = launched, false = skipped) +// second = elapsed time (ms) of the kernel launch, valid only if first == true +template +std::pair fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args, + const stream_config& config); + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp new file mode 100644 index 0000000000..2dbe0b2098 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp new file mode 100644 index 0000000000..6f5eca97a1 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_bf16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp new file mode 100644 index 0000000000..1c4c798af6 --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_mask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp new file mode 100644 index 0000000000..077cb7b73c --- /dev/null +++ b/example/ck_tile/01_fmha/instances/fmha_fwd_v3_d128_fp16_nmask.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd_v3.hpp" +#include "fmha_fwd_v3_impl.hpp" + +namespace ck_tile { + +using kernel_traits = + fmha_fwd_v3_kernel_traits; + +INST_FMHA_FWD_V3_DISPATCH(kernel_traits) + +} // namespace ck_tile diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh new file mode 100755 index 0000000000..9c500edf9d --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)" +VALID=0 + +for causal in 0 1 ; do +for prec in "fp16" "bf16" ; do +for hdim in 128 ; do +for perm in 0 ; do + +if [ $causal -eq 0 ]; then + mask=0 +else + mask=b:-1,0 +fi + +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID + +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID + +done +done +done +done diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 16fde15c7b..31de21a726 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" @@ -40,6 +41,8 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp new file mode 100644 index 0000000000..be14a36353 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -0,0 +1,519 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" + +#include +#include + +namespace ck_tile { + +template +struct FmhaFwdV3Kernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for lse + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + static_cast(scale_s * ck_tile::log2e_v<>), + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for mask + {}, // placeholder for lse + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + using namespace ck_tile; + + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + using namespace ck_tile; + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + batch_offset_v = key_start * kargs.stride_v; + + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto q_dram_window = make_tile_window( + q_dram, + make_tuple(number{}, number{}), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {0, i_n1}); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + auto o_acc_tile = [&]() { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lse_dram_window, + mask, + kargs.scale_s, + smem_ptr); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp new file mode 100644 index 0000000000..20d84116d4 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -0,0 +1,1198 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +#define ENABLE_ASM_MARKER 1 +#if ENABLE_ASM_MARKER +#define ASM_MARKER(marker) \ + __builtin_amdgcn_sched_barrier(0); \ + asm volatile("; [POYENC] " #marker); \ + __builtin_amdgcn_sched_barrier(0); +#else +#define ASM_MARKER(marker) +#endif + +#define ADD_SBARRIER_FOR_PHASE0 1 +#if !defined(CK_TILE_DISABLE_PACKED_FP32) +#define CK_TILE_DISABLE_PACKED_FP32 0 +#endif + +#define WARP_ID 0 +#define LANE_ID 0 + +#define ENABLE_DEBUG_STMTS 1 +#if ENABLE_DEBUG_STMTS +#define DEBUG_STMTS \ + if(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID) +#else +#define DEBUG_STMTS if constexpr(false) +#endif + +namespace ck_tile { + +template +struct CoreLoopScheduler; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) {} + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) {} + } + else + { + if constexpr(Phase == 0) {} + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) {} + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +template +struct CoreLoopScheduler +{ + template + CK_TILE_DEVICE static constexpr void schedule(ck_tile::number, + ck_tile::number) + { + using namespace ck_tile; + + if constexpr(WaveGroup == 0) + { + if constexpr(Phase == 0) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 1) {} + else if constexpr(Phase == 2) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + else if constexpr(Phase == 3) {} + } + else + { + if constexpr(Phase == 0) {} + else if constexpr(Phase == 1) + { + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + }); + } + else if constexpr(Phase == 2) {} + else if constexpr(Phase == 3) + { +#if !CK_TILE_DISABLE_PACKED_FP32 + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU +#endif + static_for<0, 8, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU + }); + } + } + } +}; + +namespace detail { +CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c) +{ +#if CK_TILE_DISABLE_PACKED_FP32 + return a * b + c; +#else + float result; + asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]" + : [result] "=v"(result) + : [a] "v"(a), [b] "s"(b), [c] "v"(c)); + return result; +#endif +} + +CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + +CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) +{ + fp16x2_t result; + asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b) +{ + bf16x2_t result; + asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]" + : [result] "=v"(result) + : [a] "v"(a), [b] "v"(b)); + return result; +} + +CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs) +{ + fp32x2_t result; + asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} +} // namespace detail + +template +struct BlockFmhaFwdV3Pipeline +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + + static_assert(std::is_same_v, + "we will the same dist tensor 'sp_compute' for both gemm0 & softmax"); + + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr ck_tile::index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr ck_tile::index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr ck_tile::index_t kAlignmentV = + kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + + static constexpr ck_tile::index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + + static constexpr ck_tile::index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + return 2; + } + }(); + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // create another LDS buffer for p + return ck_tile::max(kM0 * kN1 * sizeof(PDataType), + Policy::template GetSmemSize() + + kM0 * kN0 * sizeof(PDataType)); + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc() + { + using namespace ck_tile; + constexpr auto lds_block_desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number<1>{}, + number<1>{}); + + return lds_block_desc; + } + + // for debug only + template + CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D() + { + using namespace ck_tile; + constexpr auto lds_block_desc = make_naive_tensor_descriptor( + make_tuple(number{}), make_tuple(number<1>{}), number<1>{}, number<1>{}); + + return lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc) + { + using namespace ck_tile; + + auto tensor_view = + make_tensor_view(reinterpret_cast(base), desc); + return make_tile_window(tensor_view, desc.get_lengths(), {0, 0}); + } + + // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7 + template + CK_TILE_DEVICE static constexpr void s_waitcnt() + { + // vmcnt use bits {[15:14],[3:0]} + // expcnt use bits [6:4] + // lgkmcnt use bits [11:8] + __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) | + ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8)); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt() + { + s_waitcnt(); + } + + template + CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt() + { + s_waitcnt<63, Lgkmcnt>(); + } + + template + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + [[maybe_unused]] const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + [[maybe_unused]] const VElementFunction& v_element_func, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + [[maybe_unused]] const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize()); + auto s_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto s_lds_window = + make_tile_window(s_lds, make_tuple(number{}, number{}), {0, 0}); + + auto p_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto p_lds_window = + make_tile_window(p_lds, make_tuple(number{}, number{}), {0, 0}); + + auto o_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr)), + MakeSimpleLdsDesc()); + [[maybe_unused]] auto o_lds_window = + make_tile_window(o_lds, make_tuple(number{}, number{}), {0, 0}); + + auto m_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetSmemSize()), + MakeSimpleLdsDesc1D()); + [[maybe_unused]] auto m_lds_window = + make_tile_window(m_lds, make_tuple(number{}), {0}); + + const index_t warp_group_id = get_warp_id() / 4; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); + + auto q_dram_window = make_tile_window_linear( + q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution()); + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + auto k_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + auto v_lds_window_store = generate_tuple( + [&](auto i_buf) { + return make_lds_tile_window( + smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor(i_buf)); + }, + number<2>{}); + + statically_indexed_array( + nullptr, + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution())), + 2> + k_lds_window_load; + + statically_indexed_array( + nullptr, + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution())), + 2> + v_lds_window_load; + + decltype(make_static_distributed_tensor( + Policy::template MakeQRegTileDistribution())) q_tile; + + union kv_tile_type + { + CK_TILE_DEVICE kv_tile_type() {} + + decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile; + + decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile; + } kv_tile; + + union sp_compute_type + { + CK_TILE_DEVICE sp_compute_type() {} + + decltype(gemm_0.MakeCBlockTile()) sp_compute; + decltype(make_static_distributed_tensor( + Policy::template MakePRegTileDistribution())) p; + }; + statically_indexed_array sp; + + decltype(gemm_1.MakeCBlockTile()) o_acc; + constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd() + // instructions should we move to fmha_alu1() + static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); + + decltype(block_tile_reduce( + sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m; + decltype(m) l; + + // initialize k_lds_window and v_lds_window + static_for<0, 2, 1>{}([&](auto idx) { + k_lds_window_load(idx) = make_tile_window( + make_lds_tile_window( + static_cast(smem_ptr) + (idx)*Policy::template GetSmemSizeKV(), + Policy::template MakeKLdsLoadBlockDescriptor()), + Policy::template MakeKRegTileDistribution()); + }); + + static_for<0, 2, 1>{}([&](auto idx) { + v_lds_window_load(idx) = + make_tile_window(make_lds_tile_window( + static_cast(smem_ptr) + + (idx + 2) * Policy::template GetSmemSizeKV(), + Policy::template MakeVLdsLoadBlockDescriptor()), + Policy::template MakeVRegTileDistribution()); + }); + + { + auto origin_q = load_tile(q_dram_window); + auto transformed_q = tile_elementwise_in(q_element_func, origin_q); + + q_tile = transformed_q; + } + + clear_tile(o_acc); + set_tile(m, bit_cast(0xff7fffff)); // a bit larger than -infinity + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + index_t kv_token_start = seqlen_k_start; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + k_dram_window.init_raw(); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + v_dram_window.init_raw(); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(1 == k0_loops); + static_assert(1 == k1_loops); + static_assert(kN0 == kK1); + + constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup; + static_assert(NumWarpGroups == 2); + + [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) { + printf("[POYENC] %s (size=%d): %5.2f", + name, + decltype(dist_tensor.thread_buf_)::size(), + ck_tile::type_convert(dist_tensor.thread_buf_[0])); + static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) { + printf(", %5.2f", ck_tile::type_convert(dist_tensor.thread_buf_[i])); + }); + printf("\n"); + }; + + [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) { + const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{}); + const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + if constexpr(true || num_rows < num_cols) + { + for(int row = 0; row < num_rows; ++row) + { + int offset = desc.calculate_offset(make_tuple(row, 0)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + row, + ck_tile::type_convert(data[offset])); + for(int col = 1; col < num_cols; ++col) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + else + { + for(int col = 0; col < num_cols; ++col) + { + int offset = desc.calculate_offset(make_tuple(0, col)); + printf("[DEVICE] %s[%3d] = %5.2f", + name, + col, + ck_tile::type_convert(data[offset])); + for(int row = 1; row < num_rows; ++row) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(row, col)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + } + } + }; + + [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) { + const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{}); + + auto desc = lds_tile_window.get_bottom_tensor_view().desc_; + auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_; + + int offset = desc.calculate_offset(make_tuple(0)); + printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert(data[offset])); + for(int e = 1; e < num_elems; ++e) + { + printf(", "); + offset = desc.calculate_offset(make_tuple(e)); + printf("%5.2f", ck_tile::type_convert(data[offset])); + } + printf("\n"); + }; + + // K_mem_su_ld_insts = 1 for 32 x 128 + // V_mem_su_ld_insts = 1 for 128 x 32 + static constexpr int K_mem_su_ld_insts = 1; + static constexpr int V_mem_su_ld_insts = 1; + + auto K_mem_load = [&](auto k_lds_write_idx) { + async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); + + /// FIXME: use the future-predicting method to move the window + // move K tile windows + move_tile_window(k_dram_window, {kN0, 0}); + }; + + auto K_lds_load = [&](auto k_lds_read_idx) { + kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx)); + }; + + auto V_mem_load = [&](auto v_lds_write_idx) { + async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); + __builtin_amdgcn_sched_barrier(0); + + /// FIXME: use the future-predicting method to move the window + move_tile_window(v_dram_window, {kK1, 0}); + }; + + auto V_lds_load = [&](auto v_lds_read_idx) { + kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + }; + + decltype(m) m_old; + SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd() + /// TODO: remove the sp_delta and use sp_compute directly + statically_indexed_array{}).sp_compute), 2> sp_delta; + + auto fmha_alu0 = [&](auto sp_reg_idx) { + m_old = m; // m{j-1} + static_assert(m.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowmax value"); + auto m_latest = block_tile_reduce( + sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(m_latest.thread_buf_[0]), + bit_cast(m_latest.thread_buf_[0]), + false, + false); + /// TODO: eliminate 2 redudant v_max_f32 instructions generated by the compiler + m_latest.thread_buf_[0] = f_max(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(m_latest, f_max, bool_constant{}); +#endif + m = m_latest; + + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv( + sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx)); + }); + }); + /// TODO: move some fmha_alu1() code here if necessary + }; + + auto fmha_alu1 = [&](auto sp_reg_idx) { + constexpr auto p_spans = + std::decay_t::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + sp(sp_reg_idx).sp_compute(i_j_idx) = + ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx)); + }); + }); + + auto rowsum_p = block_tile_reduce( + sp(sp_reg_idx).sp_compute, + sequence<1>{}, + f_sum, + SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + static_assert(rowsum_p.thread_buf_.size() == 1, + "assuming that each thread holds 1 rowsum value"); +#if defined(__gfx950__) + // assuming that we are using 32x32 mfma + int32x2_t swapped_regs = + __builtin_amdgcn_permlane32_swap(bit_cast(rowsum_p.thread_buf_[0]), + bit_cast(rowsum_p.thread_buf_[0]), + false, + false); + rowsum_p.thread_buf_[0] = f_sum(bit_cast(swapped_regs.x), + bit_cast(swapped_regs.y)); +#else + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); +#endif + // update partial o_acc [0, 2) + static_for<0, ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + + // l{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx])); + + l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); + }); + + // update partial o_acc [2, fmha_alu_D_reg_cnt) + static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + + /// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite + /// the cast_tile() call into inline asm to force the conversion instructions to be + /// generated here. The fmha_alu1() call should be placed at the end of a phase. + static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { + float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); + float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + if constexpr(std::is_same_v) + { + auto casted = detail::cvt_pk_fp16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + else + { + auto casted = detail::cvt_pk_bf16_f32(x, y); + sp(sp_reg_idx).p.thread_buf_[idx] = casted.x; + sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; + } + }); + }; + + auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + } + }; + + auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) { + if constexpr(gemm_idx == 0) + { + clear_tile(sp(sp_reg_idx).sp_compute); // initialize C + gemm_0(sp(sp_reg_idx).sp_compute, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + get_slice_tile(kv_tile.k_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{})); + } + else + { + gemm_1(o_acc, + get_slice_tile(sp(sp_reg_idx).p, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{}), + get_slice_tile(kv_tile.v_tile, + sequence<0, (k1_loops - 1) * kK1>{}, + sequence{})); + fmha_alu0(number<1>{} - sp_reg_idx); + } + }; + + auto fmha_alu_D_upd = [&] { + o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0])); + + fp32x2_t pk_o_acc_scale; + pk_o_acc_scale.x = o_acc_scale; + pk_o_acc_scale.y = o_acc_scale; + + static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0); +#if CK_TILE_DISABLE_PACKED_FP32 + static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size()); + static_for{}( + [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); +#endif + + constexpr auto issued_D_reg_cnt = +#if CK_TILE_DISABLE_PACKED_FP32 + fmha_alu_D_reg_cnt + 2 +#else + fmha_alu_D_reg_cnt +#endif + ; + /// NOTICE: Use inline asm v_pk_mul_f32 to reduce latency. The fmha_alu_D_upd() call + /// should be placed at the end of a phase. + // update partial o_acc after [issued_D_reg_cnt] + static_for{}([&](auto idx) { + fp32x2_t input; + input.x = o_acc.thread_buf_[idx]; + input.y = o_acc.thread_buf_[idx + 1]; + + auto output = detail::pk_mul_f32(input, pk_o_acc_scale); + + o_acc.thread_buf_[idx] = output.x; + o_acc.thread_buf_[idx + 1] = output.y; + }); + }; + + auto fmha_mask = [&](auto sp_reg_idx) { + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), kv_token_start, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(sp(sp_reg_idx).sp_compute, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = kv_token_start + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + }; + + auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) { + if constexpr(load_type == 0) + { + V_mem_load(mem_wr_idx); + K_lds_load(lds_rd_idx); + } + else + { + K_mem_load(mem_wr_idx); + V_lds_load(lds_rd_idx); + } + }; + + auto core_loop = [&](auto cl_p) { + auto gemm0 = number<0>{}; + auto gemm1 = number<1>{}; + + auto memV = number<0>{}; + auto memK = number<1>{}; + + using Scheduler = CoreLoopScheduler; + + auto iteration = [&](auto pi) { + auto xdl_SP_p01_reg_idx = number<1>{} - pi; + auto xdl_SP_p23_reg_idx = pi; + + auto K_w0_lds_wr_idx = number<1>{} - pi; + auto V_w0_lds_wr_idx = pi; + auto K_w0_lds_rd_idx = pi; + auto V_w0_lds_rd_idx = pi; + + auto K_w4_lds_wr_idx = number<1>{} - pi; + auto V_w4_lds_wr_idx = number<1>{} - pi; + auto K_w4_lds_rd_idx = number<1>{} - pi; + auto V_w4_lds_rd_idx = pi; + + bool result = true; + + if constexpr(cl_p == 0) + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave0-3 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave0-3 (pi=1)"); + } + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + fmha_mask(xdl_SP_p01_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave0-3"); + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<2>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave0-3"); + s_waitcnt_vmcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx); + + Scheduler::schedule(cl_p, number<3>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + } + else + { +#if ADD_SBARRIER_FOR_PHASE0 + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); +#endif + __builtin_amdgcn_sched_barrier(0); + // phase0 + if constexpr(pi == 0) + { + ASM_MARKER("phase0 Wave4-7 (pi=0)"); + } + else + { + ASM_MARKER("phase0 Wave4-7 (pi=1)"); + } + cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx); + + Scheduler::schedule(cl_p, number<0>{}); + __builtin_amdgcn_sched_barrier(0); + // phase1 + ASM_MARKER("phase1 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p01_reg_idx, gemm0); + fmha_alu1(xdl_SP_p23_reg_idx); + + Scheduler::schedule(cl_p, number<1>{}); + __builtin_amdgcn_sched_barrier(0); + // phase2 + ASM_MARKER("phase2 Wave4-7"); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + fmha_mask(xdl_SP_p01_reg_idx); + + Scheduler::schedule(cl_p, number<2>{}); + kv_token_start += kN0; + if(num_total_loop <= ++i_total_loops) + { + result = false; + } + + __builtin_amdgcn_sched_barrier(0); + // phase3 + ASM_MARKER("phase3 Wave4-7"); + s_waitcnt(); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + cl_calc(xdl_SP_p23_reg_idx, gemm1); + + Scheduler::schedule(cl_p, number<3>{}); + __builtin_amdgcn_sched_barrier(0); + fmha_alu_D_upd(); + } + return result; + }; + return iteration(number<0>{}) && iteration(number<1>{}); + }; + + auto fmha_post_process = [&](auto d) { + auto ps_pi = number<1>{} - d; + auto V_lds_rd_idx = ps_pi; + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + + V_lds_load(V_lds_rd_idx); + fmha_alu1(ps_pi); + + s_waitcnt_lgkmcnt<0>(); + + auto xdl_SP_p23_reg_idx = ps_pi; + gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{}); + }; + + // pre-stage + { + ASM_MARKER("before pre-stage"); + // (1) load K0 to LDS & VGPR + K_mem_load(number<0>{}); // mem_K0 + + s_waitcnt_vmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + K_lds_load(number<0>{}); // lds_K0 + + s_waitcnt_lgkmcnt<0>(); + __builtin_amdgcn_s_barrier(); + + // (2) prefetch K1 and V0 to LDS in parallel with GEMM0 + if(1 < num_total_loop) + { + K_mem_load(number<1>{}); // mem_K1 + } + V_mem_load(number<0>{}); // mem_V0 + + // (3) mfma (Q*K0) + softmax + gemm(number<0>{}, /*gemm_idx=*/number<0>{}); + + fmha_mask(number<0>{}); + /// TODO: find better way to map fmha_alu(0,96) call + fmha_alu0(number<0>{}); + fmha_alu_D_upd(); + + kv_token_start += kN0; + ++i_total_loops; + if(num_total_loop <= i_total_loops) + { + goto label_main_loops_exit; + } + + if(2 < num_total_loop) + { + K_mem_load(number<0>{}); // mem_K2 + + s_waitcnt_vmcnt(); + __builtin_amdgcn_s_barrier(); + } + + ASM_MARKER("end pre-stage"); + } + + if(1 < num_total_loop) + { + if(warp_group_id == 0) + { + V_mem_load(number<1>{}); // V1 + K_lds_load(number<1>{}); // K1 + + asm volatile("s_setprio 0"); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<0>{})) + ; + } + if(warp_group_id != 0) + { + asm volatile("s_setprio 1"); + __builtin_amdgcn_s_barrier(); + while(core_loop(number<1>{})) + ; + } + } + label_main_loops_exit: + if(num_total_loop % 2) + { + fmha_post_process(number<1>{}); + } + if(!(num_total_loop % 2)) + { + fmha_post_process(number<0>{}); + } + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const + { + using namespace ck_tile; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp new file mode 100644 index 0000000000..e440280d7e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -0,0 +1,603 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" + +namespace ck_tile { + +struct BlockFmhaV3PipelineDefaultPolicy +{ + static constexpr ck_tile::index_t NumWarpPerGroup = 4; + static constexpr ck_tile::index_t NumThreadPerWarpGroup = + NumWarpPerGroup * ck_tile::get_warp_size(); + + // TODO: GetAlignment*() currently didn't consider if need padding or not + // so in pipeline still need check padding requirement + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentK() + { + using namespace ck_tile; + using KDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(KDataType); + } + + template + CK_TILE_DEVICE static constexpr auto GetAlignmentV() + { + using namespace ck_tile; + using VDataType = remove_cvref_t; +#if defined(__gfx950__) + constexpr index_t MaxReadSizeInBytes = 16; +#else + constexpr index_t MaxReadSizeInBytes = 4; +#endif + return MaxReadSizeInBytes / sizeof(VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using KDataType = remove_cvref_t; + return 16 / sizeof(KDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK() + { + using namespace ck_tile; + + // TODO: this is for 3d layout + using VDataType = remove_cvref_t; + return 16 / sizeof(VDataType); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentK(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KVector = GetAlignmentV(); // this is for global load + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr index_t N0 = NumIssues; + constexpr index_t N1 = LaneGroups; + constexpr index_t N2 = NumWarps; + constexpr index_t K0 = LanesPerK; + constexpr index_t K1 = KVector; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeKRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakePRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + + return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + } + + template + CK_TILE_DEVICE static constexpr auto MakeVRegTileDistribution() + { + using namespace ck_tile; + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + // compute the endcoding before transpose + constexpr auto v_block_dstr = + make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(v_block_dstr_encode), + typename Problem::VDataType>::TransposedDstrEncode{}); + + return v_block_dstr; + } + + template + CK_TILE_DEVICE static constexpr auto GetQKBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + /// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use + /// WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution here + return WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>{}; + } + }(); + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + + return BlockGemmARegBRegCRegV2{}; + } + + template + CK_TILE_DEVICE static constexpr auto GetPVBlockGemm() + { + using namespace ck_tile; + + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + /// NOTICE: in order to use load_tile_transpose() later for V tiles, we have to pass + /// WGAttrNumAccessEnum::Double instead of WGAttrNumAccessEnum::Single + using WarpGemm = WarpGemmDispatcher{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>; + + using BlockGemmPolicy = + BlockGemmARegBRegCRegV2CustomPolicy; + return BlockGemmARegBRegCRegV2{}; + } + + static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords + static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords + + template + CK_TILE_DEVICE static constexpr auto + MakeKLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return k_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + // K is always k-major, we use async-copy to load into LDS + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kKLdsPadInBytes / + sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + // this function assume K/V can share smem + constexpr index_t SingleKSize = [&]() { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemKPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = KPack; + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; + constexpr index_t LaneGroups = WarpSize / LanesPerK; + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + + return NumIssues * NumWarps * (WarpSize * KVector + kPad); + }(); + + constexpr index_t SingleVSize = [&]() { + using VDataType = remove_cvref_t; + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); + constexpr index_t kKPack = GetSmemKPackK(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + }(); + + return max(SingleKSize, SingleVSize); + } + + template + CK_TILE_DEVICE static constexpr auto + MakeVLdsStoreBlockDescriptor(ck_tile::number = ck_tile::number<0>{}) + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentV(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps. + // Optimize this for lds_read speed + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = + kKPerBlock / KVector; // how many lane (within a wave) to load K + constexpr index_t LaneGroups = + WarpSize / + LanesPerK; // how many groups (within a wave), they may load different N, but same K + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset( + make_tuple(number{}, // n0 + number{}, // n1 + number{}, // n2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number<(IBuf + 2) * GetSingleSmemElementSpaceSize()>{}, + number{}, + number<1>{}); + + // TODO this layout is hard coded, and will be used in async copy buffer view load + // in LDS the real layout is (bufs, N0, N2, N1*K0*K1) + constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return v_lds_block_desc_issues_warps_lanes; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVLdsLoadBlockDescriptor() + { + using namespace ck_tile; + + /// FIXME: rename the kNPerBlock & kKPerBlock since the kN1 is congtigous dimension + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; + constexpr index_t WarpSize = ck_tile::get_warp_size(); + + constexpr index_t KPack = GetSmemVPackK(); // this is for lds + constexpr index_t KVector = GetAlignmentK(); // this is for global load + constexpr index_t kPad = + kVLdsPadInBytes / + sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps + + static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0); + constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave + constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave + constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps); + static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector)); + + constexpr auto v_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // n0 + number{}, // n2 + number{}, // n1 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() + { + using namespace ck_tile; + + static_assert(MakeKLdsLoadBlockDescriptor().get_element_space_size() == + MakeKLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t k_element_space_size = + MakeKLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(MakeVLdsLoadBlockDescriptor().get_element_space_size() == + MakeVLdsStoreBlockDescriptor().get_element_space_size()); + constexpr index_t v_element_space_size = + MakeVLdsLoadBlockDescriptor().get_element_space_size(); + + static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <= + GetSingleSmemElementSpaceSize()); + + /// TODO: override GetSingleSmemElementSpaceSize() to align with MakeKLdsBlockDescriptor() & + /// MakeVLdsBlockDescriptor() + static_assert(std::is_same_v); + constexpr index_t kv_element_space_size_in_bytes = + GetSingleSmemElementSpaceSize() * sizeof(typename Problem::KDataType); + + return kv_element_space_size_in_bytes; + } + + template + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 4 * GetSmemSizeKV(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 86ac713b6f..7775848195 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -262,4 +263,47 @@ struct BlockFmhaFwdAppendKVPipelineProblem static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; +template +struct BlockFmhaFwdV3PipelineProblem +{ + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps; + static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); + + static constexpr bool kIsGroupMode = kIsGroupMode_; + + // attributes from traits + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kStoreLSE = Traits::kStoreLSE; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index fb4713ccc0..cd3893f5cf 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -148,4 +148,20 @@ struct TileFmhaBwdConvertQGradTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaFwdV3Traits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + } // namespace ck_tile From 33418b201f53259ebc192441eedf1098056ba6a7 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Tue, 2 Sep 2025 11:18:53 +0800 Subject: [PATCH 003/404] Fix naming issue (#2762) --- include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index ddc5c5447f..9d848dfd7a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -65,9 +65,9 @@ struct FmhaFwdKernel static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; #if defined(__gfx950__) - static constexpr bool kIsAvialable = true; + static constexpr bool kIsAvailable = true; #else - static constexpr bool kIsAvialable = !kUseTrLoad; + static constexpr bool kIsAvailable = !kUseTrLoad; #endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; @@ -1046,7 +1046,7 @@ struct FmhaFwdKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { - if constexpr(kIsAvialable) + if constexpr(kIsAvailable) run_(std::move(kargs)); } From 022f369deb06e202f6a0dd72b6759c9332e6d395 Mon Sep 17 00:00:00 2001 From: Michael Mcminn <47832147+UD-mmcminn@users.noreply.github.com> Date: Tue, 2 Sep 2025 04:35:07 -0400 Subject: [PATCH 004/404] =?UTF-8?q?Adding=20fix=20for=20the=20gfx908=20to?= =?UTF-8?q?=20the=20GEMM=20MFMA=20implementaitons=20of=20WarpGem=E2=80=A6?= =?UTF-8?q?=20(#2751)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding fix for the gfx908 to the GEMM MFMA implementaitons of WarpGemmMfmaBf16Bf16F32M4N64K16 WarpGemmMfmaBf16Bf16F32M64N4K16 * Adding support for offload target gfx9-4-generic * This duplication here isn't ideal --- include/ck/ck.hpp | 5 +- include/ck_tile/core/config.hpp | 5 +- .../warp/warp_gemm_attribute_mfma_impl.hpp | 60 +++++++++++++++++-- 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 09801203ba..b8a1afec4e 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -50,10 +50,11 @@ #endif // define general macros for various architectures -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7b5b862cb1..0d4aa58026 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,10 +3,11 @@ #pragma once -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \\ + defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx9__ #endif -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define __gfx94__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 0831cf85c4..11a8416fb2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -660,8 +660,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -673,9 +685,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec; @@ -724,8 +750,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl) else { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx908__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else ignore = c_vec; ignore = a_vec; @@ -737,9 +775,23 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx9__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; #else ignore = a_vec; ignore = b_vec; From 4419fc34a299b70e6fc9b8894a7efc92be173226 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Tue, 2 Sep 2025 14:14:10 +0300 Subject: [PATCH 005/404] Fix formatting problem (#2768) --- include/ck/ck.hpp | 4 ++-- include/ck_tile/core/config.hpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index b8a1afec4e..5783605f8d 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -50,8 +50,8 @@ #endif // define general macros for various architectures -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \ - defined(__gfx950__) || defined(__gfx9_4_generic__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \ + defined(__gfx9_4_generic__) #define __gfx9__ #endif #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 0d4aa58026..a7fe6b37e1 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,8 +3,8 @@ #pragma once -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || \\ - defined(__gfx950__) || defined(__gfx9_4_generic__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \ + defined(__gfx9_4_generic__) #define __gfx9__ #endif #if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) From bab747b017f3f2102d59f15f08953055d5edc0f4 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Wed, 3 Sep 2025 00:12:24 +0800 Subject: [PATCH 006/404] Fix typo in profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp (#2767) --- profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 32bdf05771..33a889afe7 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -107,7 +107,7 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, ck::utils::validate_gemm_stride(M, K, StrideA, "StrideA"); ck::utils::validate_gemm_stride(K, N, StrideB, "StrideB"); - ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); + ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, From 0e322200e5c959bddea8dda101197f685bc3c22c Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Tue, 2 Sep 2025 11:32:01 -0400 Subject: [PATCH 007/404] Fixing python backward compatibility issue in benchmarking script. --- tile_engine/ops/gemm/gemm_instance_builder.py | 3 ++- tile_engine/ops/gemm/validation_utils.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index d679be7b84..c2214da613 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -8,6 +8,7 @@ import multiprocessing import concurrent.futures from pathlib import Path import logging +from typing import Optional from validation_utils import is_tile_config_valid, is_trait_combination_valid logging.basicConfig(level=logging.INFO) @@ -325,7 +326,7 @@ class GemmKernelBuilder: "c": "ck_tile::tensor_layout::gemm::ColumnMajor", } - def _get_abc_layouts(self, layout_code: str | None = None): + def _get_abc_layouts(self, layout_code: Optional[str] = None): """ Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. If layout_code is None, use self.layout. diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 4948fd5744..7367f2446d 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -11,6 +11,7 @@ import subprocess import re from functools import lru_cache import logging +from typing import Tuple, List # Element size mapping for different data types ELEMENT_SIZE_MAP = { @@ -169,7 +170,7 @@ def validate_dimension_alignment( warp_tile_m: int, warp_tile_n: int, warp_tile_k: int, -) -> tuple[bool, list[str]]: +) -> Tuple[bool, List[str]]: """Check if tile dimensions are properly aligned with warp dimensions.""" alignment_issues = [] @@ -196,7 +197,7 @@ def validate_lds_capacity( a_datatype: str, b_datatype: str, pipeline: str, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate LDS capacity requirements.""" matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) @@ -224,7 +225,7 @@ def validate_warp_tile_combination( b_datatype: str, c_datatype: str, gpu_name: str = None, -) -> tuple[bool, str]: +) -> Tuple[bool, str]: """Validate warp tile combination against GPU-specific supported combinations.""" if gpu_name is None: gpu_name = get_gpu_name_by_id(0) From 9f35cde374381ba76ea793d0794ac31ced075bb0 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 3 Sep 2025 00:51:56 +0800 Subject: [PATCH 008/404] [CK_TILE] Fix fmha_fwd_v3() Default2DEpilogue usage (#2765) * Fix Default2DEpilogue usage * Fix Default2DEpilogue usage for batch_prefill --- include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp | 2 +- include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 2850ce3379..fcd512056d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1143,7 +1143,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index be14a36353..87021354aa 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -513,7 +513,7 @@ struct FmhaFwdV3Kernel make_tuple(number{}, number{}), {i_m0, i_n1}); - EpiloguePipeline{}(o_dram_window, o_acc_tile); + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); } }; } // namespace ck_tile From 8d43155bce73226b0030dcfbb12f95e62c4abe46 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 2 Sep 2025 14:04:21 -0400 Subject: [PATCH 009/404] fix the errors (#2771) --- script/cmake-ck-dev.sh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 217ec998bd..086359a79f 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -28,16 +28,10 @@ if [ $# -ge 1 ]; then REST_ARGS=("$@") ;; *) - echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" - GPU_TARGETS="gfx908;gfx90a;gfx942" - shift 1 REST_ARGS=("$@") ;; esac else - echo "No GPU targets provided, using default targets: gfx908;gfx90a;gfx942" - GPU_TARGETS="gfx908;gfx90a;gfx942" - shift 1 REST_ARGS=("$@") fi From 00fd72b2d4c807c97b0adec6cec4986d098ce4fa Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Wed, 3 Sep 2025 08:07:09 +0800 Subject: [PATCH 010/404] Fix a typo in intrin_wmma_bf16_16x16x16_bf16_w32 (#2727) __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32 is only available in gfx11. --- include/ck/utility/amd_wmma.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index e14c0d62a8..09a462d016 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -104,7 +104,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx11__) || defined(__gfx12__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); From e1ab460d2d2f58c3bfc18f1ff360a34aeb7f478f Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Tue, 2 Sep 2025 23:40:18 -0600 Subject: [PATCH 011/404] [CK TILE GEMM] Fix building issues (#2772) - Add `WarpGemmMfma_f32_16x16x128_[fp8|bf8]_[fp8|bf8]_CTransposed` - Replace `__gfx950__` with `CK_GFX950_SUPPORT` --- .../38_block_scale_gemm/gemm_utils.hpp | 4 ++-- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 24 +++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 ++++ .../test_gemm_aquant_utils.hpp | 2 +- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index d64297cb35..930cdefb7e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -35,7 +35,7 @@ constexpr ck_tile::index_t get_k_warp_tile() template constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 16 : 64; else diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 87772f78fc..f83bbc2a18 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -301,6 +301,30 @@ using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; +template +using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + +template +using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed = + WarpGemmImpl, + AttrNumAccess>>; + template using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< WarpGemmAttributeMfma, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 5021fb9907..1d3dd2ae6f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -112,6 +112,10 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp index cf9bf18c5a..61bb1a8bdd 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp @@ -18,7 +18,7 @@ template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) From 4d041837ade7ae01900a0442d939f80b723b1631 Mon Sep 17 00:00:00 2001 From: rahjain-amd Date: Wed, 3 Sep 2025 12:01:29 +0530 Subject: [PATCH 012/404] Add json dump support to output details from CK/CKTile Examples. (#2551) * Adding RapidJson Library * Adding Json Dumps in all CK_Tile Examples Not verified yet * Adding json to cktile Batched Transpose * adding json dumps to layernorm2d_fwd * Adding json dump to flatmm_basic * Adding RapidJson Library * Adding Json Dumps in all CK_Tile Examples Not verified yet * Adding json to cktile Batched Transpose * adding json dumps to layernorm2d_fwd * Adding json dump to flatmm_basic * Adding json in 03_gemm * Add json dump to 16_batched_gemm * Add json dump to gemm_multi_d_fp16 * Add json dump to grouped_gemm * fix fmha_bwd/fwd * Fix clang-format errors exclude include/rapidjson in jenkins as its a third-party library * Saparating function and defination. * Update Documentation of 03_gemm * Refactoring as per code review * Disable fp8 instances on unsupported targets (#2592) * Restrict building of gemm_universal_preshuffle_f8 instances to specific targets in CMakeLists.txt * Add condition to skip gemm_xdl_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt * Add conditions to skip unsupported targets for gemm_universal_preshuffle_f8 and gemm_xdl_universal_preshuffle_f8 instances in CMakeLists.txt * Refine conditions to exclude gemm_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt --------- Co-authored-by: AviralGoelAMD * fix clang format * remove duplicate lines of code from library/src/tensor_operation_instance/gpu/CMakeLists.txt * Fixing Readme and unifying jsondumps * adding moe_smoothquant * adding fused_moe * Fixing Readme for batched_gemm * Fixing Readme for grouped_gemm * adding flatmm * adding gemm_multi_d_fp16 * adding elementwise * adding File name when json is dumped * Fixing Reduce after merge * adding batched_transpose * Adding Warptile in Gemm * Fixing Clang Format --------- Co-authored-by: Aviral Goel Co-authored-by: AviralGoelAMD Co-authored-by: illsilin_amdeng --- Jenkinsfile | 2 + example/CMakeLists.txt | 1 + example/ck_tile/01_fmha/README.md | 2 + example/ck_tile/01_fmha/fmha_bwd.cpp | 674 ++-- example/ck_tile/01_fmha/fmha_fwd.cpp | 488 +-- example/ck_tile/02_layernorm2d/README.md | 2 + .../02_layernorm2d/layernorm2d_fwd.cpp | 23 +- example/ck_tile/03_gemm/README.md | 14 +- example/ck_tile/03_gemm/gemm_utils.hpp | 3 + example/ck_tile/03_gemm/run_gemm_example.inc | 106 +- example/ck_tile/05_reduce/reduce.cpp | 28 +- example/ck_tile/06_permute/permute.cpp | 10 +- example/ck_tile/09_topk_softmax/README.md | 2 + .../ck_tile/09_topk_softmax/topk_softmax.cpp | 24 +- .../09_topk_softmax/topk_softmax_api.cpp | 2 +- example/ck_tile/10_rmsnorm2d/README.md | 31 +- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 22 +- .../11_add_rmsnorm2d_rdquant/README.md | 14 +- .../add_rmsnorm2d_rdquant_fwd.cpp | 20 +- example/ck_tile/12_smoothquant/README.md | 13 +- .../ck_tile/12_smoothquant/smoothquant.cpp | 18 +- example/ck_tile/13_moe_sorting/README.md | 48 +- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 24 +- example/ck_tile/14_moe_smoothquant/README.md | 22 +- .../14_moe_smoothquant/moe_smoothquant.cpp | 20 +- example/ck_tile/15_fused_moe/README.md | 38 + example/ck_tile/15_fused_moe/main.cpp | 51 +- example/ck_tile/16_batched_gemm/README.md | 40 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 5 +- .../run_batched_gemm_example.inc | 101 +- example/ck_tile/17_grouped_gemm/README.md | 29 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 5 +- .../run_grouped_gemm_example.inc | 69 +- example/ck_tile/18_flatmm/README.md | 23 +- example/ck_tile/18_flatmm/flatmm_basic.hpp | 7 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 83 +- example/ck_tile/19_gemm_multi_d/README.md | 32 +- .../19_gemm_multi_d/gemm_multi_d_fp16.hpp | 4 +- .../run_gemm_multi_d_fp16_example.inc | 113 +- .../grouped_convolution_utils.hpp | 3 +- .../21_elementwise/elementwise_example.cpp | 17 +- .../elementwise_example_add_4d.cpp | 17 +- .../elementwise_example_transpose.cpp | 17 +- .../elementwise_example_unary.cpp | 17 +- .../batched_transpose_example.cpp | 20 + example/include/json_dump.hpp | 700 ++++ include/rapidjson/allocators.h | 693 ++++ include/rapidjson/cursorstreamwrapper.h | 78 + include/rapidjson/document.h | 3044 +++++++++++++++ include/rapidjson/encodedstream.h | 299 ++ include/rapidjson/encodings.h | 716 ++++ include/rapidjson/error/en.h | 176 + include/rapidjson/error/error.h | 285 ++ include/rapidjson/filereadstream.h | 99 + include/rapidjson/filewritestream.h | 104 + include/rapidjson/fwd.h | 151 + include/rapidjson/internal/biginteger.h | 297 ++ include/rapidjson/internal/clzll.h | 71 + include/rapidjson/internal/diyfp.h | 261 ++ include/rapidjson/internal/dtoa.h | 249 ++ include/rapidjson/internal/ieee754.h | 78 + include/rapidjson/internal/itoa.h | 308 ++ include/rapidjson/internal/meta.h | 186 + include/rapidjson/internal/pow10.h | 55 + include/rapidjson/internal/regex.h | 739 ++++ include/rapidjson/internal/stack.h | 232 ++ include/rapidjson/internal/strfunc.h | 83 + include/rapidjson/internal/strtod.h | 293 ++ include/rapidjson/internal/swap.h | 46 + include/rapidjson/istreamwrapper.h | 128 + include/rapidjson/memorybuffer.h | 70 + include/rapidjson/memorystream.h | 71 + include/rapidjson/msinttypes/inttypes.h | 316 ++ include/rapidjson/msinttypes/stdint.h | 300 ++ include/rapidjson/ostreamwrapper.h | 81 + include/rapidjson/pointer.h | 1482 ++++++++ include/rapidjson/prettywriter.h | 277 ++ include/rapidjson/rapidjson.h | 741 ++++ include/rapidjson/reader.h | 2246 ++++++++++++ include/rapidjson/schema.h | 3261 +++++++++++++++++ include/rapidjson/stream.h | 223 ++ include/rapidjson/stringbuffer.h | 121 + include/rapidjson/uri.h | 481 +++ include/rapidjson/writer.h | 721 ++++ script/clang-format-overwrite.sh | 9 +- .../moe_sorting/test_moe_sorting_cases.inc | 0 test/ck_tile/permute/test_permute_cases.inc | 0 .../smoothquant/test_smoothquant_cases.inc | 0 88 files changed, 21219 insertions(+), 856 deletions(-) create mode 100644 example/include/json_dump.hpp create mode 100644 include/rapidjson/allocators.h create mode 100644 include/rapidjson/cursorstreamwrapper.h create mode 100644 include/rapidjson/document.h create mode 100644 include/rapidjson/encodedstream.h create mode 100644 include/rapidjson/encodings.h create mode 100644 include/rapidjson/error/en.h create mode 100644 include/rapidjson/error/error.h create mode 100644 include/rapidjson/filereadstream.h create mode 100644 include/rapidjson/filewritestream.h create mode 100644 include/rapidjson/fwd.h create mode 100644 include/rapidjson/internal/biginteger.h create mode 100644 include/rapidjson/internal/clzll.h create mode 100644 include/rapidjson/internal/diyfp.h create mode 100644 include/rapidjson/internal/dtoa.h create mode 100644 include/rapidjson/internal/ieee754.h create mode 100644 include/rapidjson/internal/itoa.h create mode 100644 include/rapidjson/internal/meta.h create mode 100644 include/rapidjson/internal/pow10.h create mode 100644 include/rapidjson/internal/regex.h create mode 100644 include/rapidjson/internal/stack.h create mode 100644 include/rapidjson/internal/strfunc.h create mode 100644 include/rapidjson/internal/strtod.h create mode 100644 include/rapidjson/internal/swap.h create mode 100644 include/rapidjson/istreamwrapper.h create mode 100644 include/rapidjson/memorybuffer.h create mode 100644 include/rapidjson/memorystream.h create mode 100644 include/rapidjson/msinttypes/inttypes.h create mode 100644 include/rapidjson/msinttypes/stdint.h create mode 100644 include/rapidjson/ostreamwrapper.h create mode 100644 include/rapidjson/pointer.h create mode 100644 include/rapidjson/prettywriter.h create mode 100644 include/rapidjson/rapidjson.h create mode 100644 include/rapidjson/reader.h create mode 100644 include/rapidjson/schema.h create mode 100644 include/rapidjson/stream.h create mode 100644 include/rapidjson/stringbuffer.h create mode 100644 include/rapidjson/uri.h create mode 100644 include/rapidjson/writer.h mode change 100755 => 100644 test/ck_tile/moe_sorting/test_moe_sorting_cases.inc mode change 100755 => 100644 test/ck_tile/permute/test_permute_cases.inc mode change 100755 => 100644 test/ck_tile/smoothquant/test_smoothquant_cases.inc diff --git a/Jenkinsfile b/Jenkinsfile index e7e57aded9..4350816013 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1129,6 +1129,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ + | grep -v 'include/rapidjson' \ | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ @@ -1158,6 +1159,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ + | grep -v 'include/rapidjson' \ | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 7bd628edf2..7dc2f92bf9 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/library/include + ${PROJECT_SOURCE_DIR}/example/include ) add_custom_target(examples) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index f72d7afa02..cb6cd44f64 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -74,6 +74,8 @@ args: -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:fmha_fwd.json) ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9f1e0f6948..8ad2a3de04 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -5,6 +5,7 @@ #include "ck_tile/host.hpp" #include "mask.hpp" #include "utils.hpp" +#include "json_dump.hpp" #include #include @@ -94,7 +95,9 @@ auto create_args(int argc, char* argv[]) .insert("deterministic", "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion " - "will not be used"); + "will not be used") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "fmha_bwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -584,53 +587,54 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; + bool pass = true; if(!do_validation) { std::cout << std::flush << std::endl; - return true; } - - bool pass = true; - - std::vector> q_host_refs; - std::vector> k_host_refs; - std::vector> v_host_refs; - std::vector> o_host_refs; - std::vector> randval_host_refs; - std::vector> p_hp_host_refs; - std::vector> p_lp_host_refs; - - randval_buf.FromDevice(randval_host.data()); - - for(ck_tile::index_t wb = 0; wb < batch; ++wb) + else { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + std::vector> q_host_refs; + std::vector> k_host_refs; + std::vector> v_host_refs; + std::vector> o_host_refs; + std::vector> randval_host_refs; + std::vector> p_hp_host_refs; + std::vector> p_lp_host_refs; - // adjust matrix index according to the mode - const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); - const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + randval_buf.FromDevice(randval_host.data()); - ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k - ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k - ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n - ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o - ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m - ck_tile::HostTensor randval_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n - ck_tile::HostTensor s_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n - ck_tile::HostTensor p_hp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision - ck_tile::HostTensor p_dropped_hp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision - ck_tile::HostTensor p_lp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - ck_tile::index_t nr = nhead / nhead_k; + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); - // clang-format off + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); // lse_g_m + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n + ck_tile::HostTensor p_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision + ck_tile::HostTensor p_dropped_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision + ck_tile::HostTensor p_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off // permute if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); @@ -642,281 +646,294 @@ bool run(const ck_tile::ArgParser& arg_parser) if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); - // clang-format on - - // reference - // S = scale * Q * K^T - ck_tile::reference_batched_gemm( - q_host_ref, - k_host_ref, - s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k - - if(bias.type == bias_enum::elementwise_bias) - { - // elementwise bias - ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); - // clang-format off - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on - // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, - // real_seqlen_k] - ck_tile:: - reference_batched_elementwise( - s_host_ref, bias_host_ref, s_host_ref); - } - else if(bias.type == bias_enum::alibi) - { - // alibi construct elementwise bias to verify - auto alibi_host = [&]() { - if(mask.type != mask_enum::no_mask) - { - return ck_tile::make_alibi_from_lr_mask( - 0, - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - static_cast(mask.type)); - } - else - { - return ck_tile::Alibi{ - 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; - } - }(); + // reference + // S = scale * Q * K^T + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k - ck_tile::HostTensor alibi_bias_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - auto i_b_slope = bias.rank_info == 0 ? 0 : wb; - for(auto i_h = 0; i_h < nhead; i_h++) + if(bias.type == bias_enum::elementwise_bias) { - AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); - alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope - : -current_slope; - for(auto i_r = 0; i_r < real_seqlen_q; i_r++) - { - for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) { - AccDataType pixel = 0; - alibi_host.update(pixel, i_r, i_c); - alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL + ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } } } - } - // [nhead, real_seqlen_q, real_seqlen_k] - ck_tile:: - reference_batched_elementwise( + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( s_host_ref, alibi_bias_host_ref, s_host_ref); - } + } - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) + if(mask.type == mask_enum::no_mask) + { ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - else + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - } - ck_tile::reference_batched_softmax( - s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + ck_tile::reference_batched_softmax( + s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); - if(p_drop > 0) - { - p_dropped_hp_host_ref = p_hp_host_ref; - randval_host_ref.ForEach([&](auto& self, auto idx) { - self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); - }); - ck_tile::reference_batched_dropout( - p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); - } - else - { - p_lp_host_ref = p_hp_host_ref.template CopyAsType(); - } + if(p_drop > 0) + { + p_dropped_hp_host_ref = p_hp_host_ref; + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); + } + else + { + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); + } - // O = P * V - ck_tile::reference_batched_gemm( - p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n + // O = P * V + ck_tile::reference_batched_gemm( + p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n - // clang-format off + // clang-format off // permute if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); }); - // clang-format on + // clang-format on - q_host_refs.push_back(q_host_ref); - k_host_refs.push_back(k_host_ref); - v_host_refs.push_back(v_host_ref); - o_host_refs.push_back(o_host_ref); - p_hp_host_refs.push_back(p_hp_host_ref); - p_lp_host_refs.push_back(p_lp_host_ref); - if(p_drop > 0) - { - randval_host_refs.push_back(randval_host_ref); + q_host_refs.push_back(q_host_ref); + k_host_refs.push_back(k_host_ref); + v_host_refs.push_back(v_host_ref); + o_host_refs.push_back(o_host_ref); + p_hp_host_refs.push_back(p_hp_host_ref); + p_lp_host_refs.push_back(p_lp_host_ref); + if(p_drop > 0) + { + randval_host_refs.push_back(randval_host_ref); + } } - } - // set to bad values to check if the kernel writes to these buffers - ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); - ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); - ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); - dq_buf.ToDevice(dq_host.data()); - dk_buf.ToDevice(dk_host.data()); - dv_buf.ToDevice(dv_host.data()); + // set to bad values to check if the kernel writes to these buffers + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + dq_buf.ToDevice(dq_host.data()); + dk_buf.ToDevice(dk_host.data()); + dv_buf.ToDevice(dv_host.data()); - o_buf.ToDevice(o_host.data()); - lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); - dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + o_buf.ToDevice(o_host.data()); + lse_buf.ToDevice(lse_host.data()); + dq_buf.SetZero(); + dbias_buf.SetZero(); + dq_acc_buf.SetZero(); - ck_tile::stream_config stream_config_v{ - nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - fmha_bwd(fmha_traits, fmha_args, stream_config_v); + ck_tile::stream_config stream_config_v{ + nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + fmha_bwd(fmha_traits, fmha_args, stream_config_v); - dq_buf.FromDevice(dq_host.data()); - dk_buf.FromDevice(dk_host.data()); - dv_buf.FromDevice(dv_host.data()); - dbias_buf.FromDevice(dbias_host.data()); + dq_buf.FromDevice(dq_host.data()); + dk_buf.FromDevice(dk_host.data()); + dv_buf.FromDevice(dv_host.data()); + dbias_buf.FromDevice(dbias_host.data()); - for(ck_tile::index_t wb = 0; wb < batch; ++wb) - { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - // adjust matrix index according to the mode - const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); - const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); - ck_tile::HostTensor do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o - ck_tile::HostTensor ds_hp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision - ck_tile::HostTensor ds_lp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision - ck_tile::HostTensor dp_hp_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision - ck_tile::HostTensor dbias_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n - ck_tile::HostTensor dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k - ck_tile::HostTensor dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k - ck_tile::HostTensor dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor do_host_ref( + {nhead, real_seqlen_q, hdim_v}); // do_g_m_o + ck_tile::HostTensor ds_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision + ck_tile::HostTensor ds_lp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision + ck_tile::HostTensor dp_hp_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision + ck_tile::HostTensor dbias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_ref( + {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_ref( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_ref( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o - // clang-format off + // clang-format off if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); }); else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); }); - // clang-format on + // clang-format on - // dP = dO@V x Z w/ dropout - // dP = dO@V w/o dropout - auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o - ck_tile::reference_batched_gemm( - do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o + // dP = dO@V x Z w/ dropout + // dP = dO@V w/o dropout + auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o + ck_tile::reference_batched_gemm( + do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o - if(p_drop > 0) - { - ck_tile::reference_batched_dropout( - dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); - } + if(p_drop > 0) + { + ck_tile::reference_batched_dropout( + dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop); + } - // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ck_tile::make_ParallelTensorFunctor( - [&](auto i0, auto i1, auto i2) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * - ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); - } - ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( - p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); - }, - ds_hp_host_ref.mDesc.get_lengths()[0], - ds_hp_host_ref.mDesc.get_lengths()[1], - ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); + // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); - if(use_dbias) - { - dbias_host_ref = ds_hp_host_ref.template CopyAsType(); - } + if(use_dbias) + { + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); + } - ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); - // dV = P_drop^T@dO^T - // dV = P^T@dO^T w/o dropout - auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m - auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m - ck_tile::reference_batched_gemm( - p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m + // dV = P_drop^T@dO^T + // dV = P^T@dO^T w/o dropout + auto p_t_lp_host_ref = + p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m + auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m + ck_tile:: + reference_batched_gemm( + p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m - // dQ = scale * dS@K^T - auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n - ck_tile::reference_batched_gemm( - ds_lp_host_ref, - k_t_host_ref, - dq_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n + // dQ = scale * dS@K^T + auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n + ck_tile::reference_batched_gemm( + ds_lp_host_ref, + k_t_host_ref, + dq_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n - // dK = scale * dS^T@Q^T - auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m - auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m - ck_tile::reference_batched_gemm( - ds_t_lp_host_ref, - q_t_host_ref, - dk_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m + // dK = scale * dS^T@Q^T + auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m + auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m + ck_tile::reference_batched_gemm( + ds_t_lp_host_ref, + q_t_host_ref, + dk_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m - ck_tile::HostTensor dq_host_result( - {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k - ck_tile::HostTensor dk_host_result( - {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k - ck_tile::HostTensor dv_host_result( - {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o - ck_tile::HostTensor dbias_host_result( - {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n + ck_tile::HostTensor dq_host_result( + {nhead, real_seqlen_q, hdim_q}); // dq_g_m_k + ck_tile::HostTensor dk_host_result( + {nhead, real_seqlen_k, hdim_q}); // dk_g_n_k + ck_tile::HostTensor dv_host_result( + {nhead, real_seqlen_k, hdim_v}); // dv_g_n_o + ck_tile::HostTensor dbias_host_result( + {nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n - // clang-format off + // clang-format off // permute if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); }); else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); }); @@ -932,49 +949,90 @@ bool run(const ck_tile::ArgParser& arg_parser) if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); }); else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); }); } - // clang-format on + // clang-format on - auto [rtol, atol] = get_elimit(hdim_q, hdim_v); - bool dq_cur_pass = ck_tile::check_err(dq_host_result, - dq_host_ref, - std::string("Error: QGrad Incorrect results!"), - rtol, - atol); - bool dk_cur_pass = ck_tile::check_err(dk_host_result, - dk_host_ref, - std::string("Error: KGrad Incorrect results!"), - rtol, - atol); - bool dv_cur_pass = ck_tile::check_err(dv_host_result, - dv_host_ref, - std::string("Error: VGrad Incorrect results!"), - rtol, - atol); + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); + bool dq_cur_pass = ck_tile::check_err(dq_host_result, + dq_host_ref, + std::string("Error: QGrad Incorrect results!"), + rtol, + atol); + bool dk_cur_pass = ck_tile::check_err(dk_host_result, + dk_host_ref, + std::string("Error: KGrad Incorrect results!"), + rtol, + atol); + bool dv_cur_pass = ck_tile::check_err(dv_host_result, + dv_host_ref, + std::string("Error: VGrad Incorrect results!"), + rtol, + atol); - bool dbias_cur_pass = true; - if(use_dbias) - { - dbias_cur_pass = ck_tile::check_err(dbias_host_result, - dbias_host_ref, - std::string("Error: BiasGrad Incorrect results!"), - rtol, - atol); + bool dbias_cur_pass = true; + if(use_dbias) + { + dbias_cur_pass = + ck_tile::check_err(dbias_host_result, + dbias_host_ref, + std::string("Error: BiasGrad Incorrect results!"), + rtol, + atol); + } + pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); + if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) + { + std::cerr << "mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } } - pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass); - if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass)) - { - std::cerr << "mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - break; - } + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; - + if(arg_parser.get_int("json") == 1) + { + dump_fmha_bwd_json_results( + arg_parser.get_str("jsonfile"), + data_type, + mode == mode_enum::batch ? "batch" : "group", + i_perm ? "true" : "false", + o_perm ? "true" : "false", + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + scale, + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + use_dbias ? "true" : "false", + p_drop, + s_randval, + deterministic, + mask.type == mask_enum::no_mask + ? "no_mask" + : (mask.type == mask_enum::window_generic + ? "window_generic" + : (mask.type == mask_enum::mask_top_left + ? "mask_top_left" + : (mask.type == mask_enum::mask_bottom_right ? "mask_bottom_right" + : "mask_generic"))), + mask.left, + mask.right, + workspace_size, + pass, + ave_time, + tflops, + gb_per_sec); + } return pass; } diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d0f8e3798c..f6b5b879bd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -7,6 +7,7 @@ #include "mask.hpp" #include "rotary.hpp" #include "utils.hpp" +#include "json_dump.hpp" #include #include @@ -138,7 +139,9 @@ auto create_args(int argc, char* argv[]) .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe") .insert("cache_batch_idx", "0", "whether to use index map to the kvcache") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -1137,12 +1140,12 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush << std::endl; + bool pass = true; if(do_validation == 0) { std::cout << std::flush << std::endl; - return true; } - if(do_validation == 2) + else if(do_validation == 2) { // NOTE: use gpu to do validation ck_tile::naive_attention_fwd_traits naive_t; @@ -1188,64 +1191,67 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.FromDevice(o_host.data()); // TODO: ugly auto [rtol_, atol_] = get_elimit(init_method); - bool pass_ = ck_tile::check_err( + pass = ck_tile::check_err( o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); - std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl; - return pass_; + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - - o_buf.FromDevice(o_host.data()); - lse_buf.FromDevice(lse_host.data()); - randval_buf.FromDevice(randval_host.data()); - - auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::scales{scale_p}; - else - return ck_tile::identity{}; - }(); - - auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) - return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else - return ck_tile::identity{}; - }(); - - float p_undrop = 1.0 - p_drop; - uint8_t p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - float rp_undrop = 1.0 / p_undrop; - - bool pass = true; - for(ck_tile::index_t wb = 0; wb < batch; ++wb) + else { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; - // adjust matrix index according to the mode - const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); - const ck_tile::index_t cache_b_idx = - (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); - const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = - (mode == mode_enum::batch - ? 0 - : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); - ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); - ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); - ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); - ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); - ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); - ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); - ck_tile::index_t nr = nhead / nhead_k; + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; - // clang-format off + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); + const ck_tile::index_t query_offset = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] + : seqstart_k_with_padding_host[wb])); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off // permute if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); @@ -1379,198 +1385,179 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif - // clang-format on - - // reference - ck_tile::reference_batched_gemm( - q_host_ref, - k_host_ref, - s_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales(scale_s)); - - if(0.f < logits_soft_cap) - { - ck_tile::reference_unary_elementwise( - s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { - return ck_tile::type_convert( - logits_soft_cap * - std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); - }); - } - - if(bias.type == bias_enum::elementwise_bias) - { - // elementwise bias - ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); - // clang-format off - if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); - else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on - // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, - // real_seqlen_k] - ck_tile::reference_batched_elementwise( - s_host_ref, bias_host_ref, s_host_ref); - } - else if(bias.type == bias_enum::alibi) - { - // alibi construct elementwise bias to verify - auto alibi_host = [&]() { - if(mask.type != mask_enum::no_mask) - { - return ck_tile::make_alibi_from_lr_mask( - 0, - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - static_cast(mask.type)); - } - else - { - return ck_tile::Alibi{ - 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; - } - }(); + // reference + ck_tile:: + reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s)); - ck_tile::HostTensor alibi_bias_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - auto i_b_slope = bias.rank_info == 0 ? 0 : wb; - for(auto i_h = 0; i_h < nhead; i_h++) + if(0.f < logits_soft_cap) { - SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); - alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope - : -current_slope; - for(auto i_r = 0; i_r < real_seqlen_q; i_r++) - { - for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) { - SaccDataType pixel = 0; - alibi_host.update(pixel, i_r, i_c); - alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL + ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } } } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); } - // [nhead, real_seqlen_q, real_seqlen_k] - ck_tile::reference_batched_elementwise( - s_host_ref, alibi_bias_host_ref, s_host_ref); - } - if(mask.type == mask_enum::no_mask) - { - ck_tile::reference_batched_masking( - s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); - } - else if(mask.type == mask_enum::window_generic) - { - ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); - } - else - { - // if left window size is negative, means causal - // else means generic (for current batch) - if(mask.left < 0) + if(mask.type == mask_enum::no_mask) + { ck_tile::reference_batched_masking( - s_host_ref, - ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - else + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, - mask.right, - real_seqlen_q, - real_seqlen_k, - mask.type == mask_enum::mask_top_left)); - } - if(lse) - { - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); - } - else - { - ck_tile::reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func); - } + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + if(lse) + { + ck_tile:: + reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile:: + reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + } - if(p_drop > 0) - { - ck_tile::HostTensor randval_host_ref( - {nhead, real_seqlen_q, real_seqlen_k}); - randval_host_ref.ForEach([&](auto& self, auto idx) { - self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); - }); - ck_tile::reference_batched_dropout( - p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - } + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } - ck_tile::reference_batched_gemm( - p_host_ref, - v_host_ref, - o_host_ref, - ck_tile::identity{}, - ck_tile::identity{}, - oacc_element_func); + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); - ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); - // clang-format off + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off // permute if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); - // clang-format on - - auto [rtol, atol] = get_elimit(init_method); - bool cur_pass = ck_tile::check_err( - o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); - pass &= cur_pass; - if(!cur_pass) - { - std::cerr << "OUT mismatch found at batch: " << wb << std::endl - << "\tseqlen_q: " << real_seqlen_q << std::endl - << "\tseqlen_k: " << real_seqlen_k << std::endl - << "\tseqstart_q: " << seqstart_q_host << std::endl - << "\tseqstart_k: " << seqstart_k_host << std::endl; - - break; - } - - if(lse) - { - ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); - lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); - }); - - cur_pass = ck_tile::check_err(lse_host_result, - lse_host_ref, - "LSE Error: Incorrect results!", - rtol, - atol, - /* allow_infinity_ref = */ true); + // clang-format on + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err(o_host_result, + o_host_ref, + std::string("OUT Error: Incorrect results!"), + rtol, + atol); pass &= cur_pass; if(!cur_pass) { - std::cerr << "LSE mismatch found at batch: " << wb << std::endl + std::cerr << "OUT mismatch found at batch: " << wb << std::endl << "\tseqlen_q: " << real_seqlen_q << std::endl << "\tseqlen_k: " << real_seqlen_k << std::endl << "\tseqstart_q: " << seqstart_q_host << std::endl @@ -1578,10 +1565,65 @@ bool run(const ck_tile::ArgParser& arg_parser) break; } + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + }); + + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + if(arg_parser.get_int("json") == 1) + { + dump_fmha_fwd_json_results(arg_parser.get_str("jsonfile"), + prec, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + squant, + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + vlayout, + pass, + ave_time, + tflops, + gb_per_sec); + } return pass; } diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index da74e2e3c1..3de48263f8 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -65,6 +65,8 @@ args: -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) -warmup cold iter (default:5) -repeat hot iter (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:layernorm2d_fwd.json) ``` diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index bdd5f2da1b..94e4734fb4 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -1,5 +1,6 @@ #include "ck_tile/host.hpp" #include "layernorm2d_fwd.hpp" +#include "json_dump.hpp" #include #include @@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[]) .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "layernorm2d_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -405,6 +408,24 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_layernorm2d_fwd_json_results(arg_parser.get_str("jsonfile"), + prec_i, + prec_o, + prec_sm, + prec_sy, + m, + n, + x_stride, + xr_stride, + y_stride, + yr_stride, + pass, + ave_time, + 0, + gb_per_sec); + } return pass; } diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 6358b76fd9..f4e0bb696c 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -9,11 +9,11 @@ mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank ../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation -make tile_example_gemm_basic -j +make tile_example_gemm_basic -j`nproc` # The memory bound pipeline on the gemm calculation -make tile_example_gemm_universal -j +make tile_example_gemm_universal -j`nproc` # The weight preshuffle pipeline on the gemm calculation -make tile_example_gemm_weight_preshuffle -j +make tile_example_gemm_weight_preshuffle -j`nproc` ``` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` @@ -30,11 +30,13 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) - -warmup number of iterations before benchmark the kernel (default:10) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:50) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) -split_k splitK value (default:1) - -init 0:random, 1:linear, 2:constant (default:1) + -init 0:random, 1:linear, 2:constant(1) (default:0) -persistent 0:non-persistent, 1:persistent (default:0) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:gemm.json) ``` diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 7f2af946e6..b41257c9e0 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -9,6 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "json_dump.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 @@ -493,6 +494,8 @@ auto create_args(int argc, char* argv[]) .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "gemm.json", "json file name to dump results") .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") .insert("rotating_count", "1000", "rotating count, defaults to 1000"); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 229771e536..0ec08ee16b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - template static constexpr inline auto is_row_major(Layout layout_) { @@ -236,23 +235,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count}); } - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Gemm kernel with \n M=" << M << " N=" << N << " K=" << K - << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C - << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name - << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") - << " Persistent=" << (persistent ? "on" : "off") << " : \n" - << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; - return ave_time; } @@ -416,32 +398,49 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat, - persistent, - flush_cache, - rotating_count); + float ave_time = invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent, + flush_cache, + rotating_count); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + bool pass = true; // memory on host to store gpu reference result @@ -496,5 +495,28 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); } + if(arg_parser.get_int("json") == 1) + { + dump_gemm_json_results(arg_parser.get_str("jsonfile"), + M, + N, + K, + stride_A, + stride_B, + stride_C, + persistent, + pass, + ave_time, + tflops, + gb_per_sec); + } + return pass; } diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index a110c2f98d..7bae39b9d5 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -3,8 +3,24 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/reduce.hpp" +#include "json_dump.hpp" #include +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -14,8 +30,10 @@ auto create_args(int argc, char* argv[]) .insert("c", "512", "c dimension") .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") - .insert("warmup", "0", "cold iter") - .insert("repeat", "1", "hot iter"); + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "reduce.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -126,6 +144,12 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_reduce_json_results( + arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec); + } + return pass; } diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index aafece0f25..e68fe4bac3 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "permute.hpp" #include "ck_tile/host.hpp" @@ -127,7 +127,8 @@ auto create_args(int argc, char* argv[]) "random seed used for initializing input tensors. 0 for " "non-deterministic seed") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("jsonfile", "permute.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -382,6 +383,11 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; } + if(arg_parser.get_int("json") == 1) + { + dump_permute_json_results(arg_parser.get_str("jsonfile"), data_type, pass, ave_time, 0, 0); + } + std::cout << std::endl; return pass; diff --git a/example/ck_tile/09_topk_softmax/README.md b/example/ck_tile/09_topk_softmax/README.md index 2e15aeaae5..8bed733d36 100644 --- a/example/ck_tile/09_topk_softmax/README.md +++ b/example/ck_tile/09_topk_softmax/README.md @@ -24,5 +24,7 @@ args: -st_o row stride of output/indices, -1 means same as topk (default:-1) -seed seed to be used, -1 means random every time (default:-1) -kname when set to 1 it will print kernel name (default:0) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:topk_softmax.json) ``` diff --git a/example/ck_tile/09_topk_softmax/topk_softmax.cpp b/example/ck_tile/09_topk_softmax/topk_softmax.cpp index 6fc25631fd..8c1da293a6 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -13,6 +13,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/reduce.hpp" #include "topk_softmax_api.hpp" +#include "json_dump.hpp" #if 0 template @@ -130,7 +131,9 @@ auto create_args(int argc, char* argv[]) .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "topk_softmax.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -273,6 +276,23 @@ bool test_topk_softmax(ck_tile::ArgParser args) } printf("valid:%s\n", rtn ? "y" : "n"); + + if(args.get_int("json") == 1) + { + dump_topk_softmax_json(args.get_str("jsonfile"), + input_prec, + weight_prec, + tokens, + experts, + topk, + stride_input, + stride_output, + ms, + 0, + 0, + rtn); + } + fflush(stdout); return rtn; } diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index c2bad24cfe..6e6bb20020 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "topk_softmax_api.hpp" diff --git a/example/ck_tile/10_rmsnorm2d/README.md b/example/ck_tile/10_rmsnorm2d/README.md index 1d27ad153e..4f2bc8b5ad 100644 --- a/example/ck_tile/10_rmsnorm2d/README.md +++ b/example/ck_tile/10_rmsnorm2d/README.md @@ -6,17 +6,34 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin ``` # in the root of ck_tile mkdir build && cd build -../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... -make tile_rmsnorm2d_fwd -j +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_rmsnorm2d_fwd -j`nproc` ``` This will result in an executable `build/bin/tile_rmsnorm2d_fwd` ## cmdline ``` args: - -m m dimension (default:3328) - -n m dimension (default:4096) - -e epsilon (default:1e-5) - -v cpu validation or not (default:1) - -prec precision (default:fp16) + -m m dimension (default:3328) + -n n dimension (default:4096) + -x_stride x row_stride, if -1 then equal to n (default:-1) + -xr_stride x residule row_stride, if -1 then equal to n (default:-1) + -y_stride y row_stride, if -1 then equal to n (default:-1) + -yr_stride y residule row_stride, if -1 then equal to n (default:-1) + -e epsilon (default:1e-5) + -save_rms save rms(invrms) or not. set to 1 in training case (default:0) +-save_unquant save result before quant (default:0) + -v cpu validation or not (default:1) + -kname print kernel name or not (default:1) + -prec_i input precision (default:fp16) + -prec_o output precision, set auto will be the same as input (default:auto) + -prec_sm output quant scale type, set auto will use fp32. used when fquant=1 (default:auto) + -prec_sy output quant scale type, set auto will use fp32. used when fquant=1 or 2 (default:auto) + -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) + -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + -s sensitive model mode, 0: for no specific model, 1: for T5-like model (default:0) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:rmsnorm2d_fwd.json) ``` diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 751b868411..bf6c4fc68e 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -1,6 +1,7 @@ #include "ck_tile/host.hpp" #include "rmsnorm2d_fwd.hpp" #include +#include "json_dump.hpp" // different threshold for different dtype template @@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[]) .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter") - .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model"); + .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "rmsnorm2d_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -437,6 +440,23 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_rmsnorm2d_fwd_json(arg_parser.get_str("jsonfile"), + prec_str, + m, + n, + x_stride, + xr_stride, + y_stride, + yr_stride, + use_model_sensitive_rmsnorm, + ave_time, + 0, + gb_per_sec, + pass); + } + return pass; } diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md index f9ba76c9e3..6c01655b75 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md @@ -6,8 +6,8 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization ``` # in the root of ck_tile mkdir build && cd build -../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... -make tile_add_rmsnorm2d_rdquant_fwd -j +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_add_rmsnorm2d_rdquant_fwd -j`nproc` ``` This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd` @@ -15,8 +15,16 @@ This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd` ``` args: -m m dimension (default:3328) - -n m dimension (default:4096) + -n n dimension (default:4096) + -stride stride per row, if -1 then equal to n (default:-1) -e epsilon (default:1e-5) + -save_x save rms(invrms) or not. set to 1 in training case (default:1) -v cpu validation or not (default:1) + -kname print kernel name or not (default:1) -prec precision (default:fp16) + -quant precision (default:int8) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:add_rmsnorm2d_rdquant_fwd.json) ``` diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 1cd375d0f5..919767a129 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -1,6 +1,7 @@ #include "ck_tile/host.hpp" #include "add_rmsnorm2d_rdquant_fwd.hpp" #include +#include "json_dump.hpp" // different threshold for different dtype template @@ -41,7 +42,9 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "precision") .insert("quant", "int8", "precision") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "add_rmsnorm2d_rdquant_fwd.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -260,6 +263,21 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_add_rmsnorm2d_rdquant_fwd_json(arg_parser.get_str("jsonfile"), + input_data_type, + quantized_data_type, + m, + n, + stride, + epsilon, + ave_time, + 0, + gb_per_sec, + pass); + } + return pass; } diff --git a/example/ck_tile/12_smoothquant/README.md b/example/ck_tile/12_smoothquant/README.md index 6b3acd558b..98205e7350 100644 --- a/example/ck_tile/12_smoothquant/README.md +++ b/example/ck_tile/12_smoothquant/README.md @@ -6,8 +6,8 @@ This folder contains example for smoothquant using ck_tile tile-programming impl ``` # in the root of ck_tile mkdir build && cd build -../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... -make tile_smoothquant -j +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_smoothquant -j`nproc` ``` This will result in an executable `build/bin/tile_smoothquant` @@ -15,7 +15,14 @@ This will result in an executable `build/bin/tile_smoothquant` ``` args: -m m dimension (default:3328) - -n m dimension (default:4096) + -n n dimension (default:4096) + -x_stride input stride per row, if -1 then equal to n (default:-1) + -y_stride output stride per row, if -1 then equal to n (default:-1) -v cpu validation or not (default:1) + -kname print kernel name or not (default:1) -prec precision (default:fp16) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:smoothquant.json) ``` diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index 02ab1cd9b1..b54babdce3 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -1,5 +1,6 @@ #include "ck_tile/host.hpp" #include "smoothquant.hpp" +#include "json_dump.hpp" #include // different threshold for different dtype @@ -39,7 +40,9 @@ auto create_args(int argc, char* argv[]) .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "smoothquant.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -202,6 +205,19 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_smoothquant_json(arg_parser.get_str("jsonfile"), + data_type, + m, + n, + x_stride, + y_stride, + ave_time, + 0, + gb_per_sec, + pass); + } return pass; } diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md index c99f40aa57..1fd40aab35 100644 --- a/example/ck_tile/13_moe_sorting/README.md +++ b/example/ck_tile/13_moe_sorting/README.md @@ -6,32 +6,36 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi ``` # in the root of ck_tile mkdir build && cd build -../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... -make tile_example_moe_sorting -j +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_sorting -j`nproc` ``` This will result in an executable `build/bin/tile_example_moe_sorting` ## example ``` args: - -v turn CPU validation on (1) or off (0). (default:1) - -pr_i index data type. Only int32 is currently supported. (default:int32) - -pr_w output weight data type. Only fp32 is currently supported. (default:fp32) - -t number of input tokens. (default:128) - If "local_t" presents, this value indicates global concurrency of all ranks. - -local_t Number of local input tokens for curent rank. (default:-1) - This value must be within range "[0, t)", or "-1"(no such feature) - This feature is to simulate EP case where where each rank has different tokens. - Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph. - -e number of num_experts (default:8) - -k topk (default:4) - -unit unit_size (default:32) --moe_buf_size moe_buf_size (default:0) - -local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1) - please make sure eid is in ascending order! - -seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1) - -kname prints the kernel name when set to 1 (default:0) - -warmup number of iterations before benchmark the kernel (default:5) - -repeat number of iterations to benchmark the kernel (default:20) - + -v turn CPU validation on (1) or off (0). (default:1) + -pr_i index data type. Only int32 is currently supported. (default:int32) + -pr_w output weight data type. Only fp32 is currently supported. (default:fp32) + -t number of input tokens. (default:128) + If "local_t" presents, this value indicates global concurrency of all ranks. + -local_t Number of local input tokens for curent rank. (default:-1) + This value must be within range "[0, t)", or "-1"(no such feature) + This feature is to simulate EP case where where each rank has different tokens. + Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph. + -e number of num_experts (default:8) + -k topk (default:4) + -unit unit_size (default:32) +-moe_buf_interm_dim interm_dim(col) of the following fmoe buf (default:0) +-moe_buf_elem_bytes fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit... (default:2) + -ci clear workspace inside API or not(if "0", require manually clear outside) (default:1) + -dispatch dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel (default:0) + -local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1) + please make sure eid is in ascending order! + -seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1) + -kname prints the kernel name when set to 1 (default:0) + -warmup number of iterations before benchmark the kernel (default:5) + -repeat number of iterations to benchmark the kernel (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:moe_sorting.json) ``` diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index e9b4ea5cd3..ef1edadf7a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -14,6 +14,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/reduce.hpp" #include "moe_sorting_api.hpp" +#include "json_dump.hpp" auto create_args(int argc, char* argv[]) { @@ -59,7 +60,9 @@ auto create_args(int argc, char* argv[]) "invoking this example") .insert("kname", "0", "prints the kernel name when set to 1") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "moe_sorting.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -437,6 +440,23 @@ bool test_moe_sorting(ck_tile::ArgParser args) printf(", (%d)", seed); printf("\n"); fflush(stdout); + + if(args.get_int("json") == 1) + { + dump_moe_sorting_json(args.get_str("jsonfile"), + index_prec, + weight_prec, + workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"), + dispatch_policy, + tokens, + num_experts, + topk, + ms, + 0, + 0, + rtn); + } + return rtn; } diff --git a/example/ck_tile/14_moe_smoothquant/README.md b/example/ck_tile/14_moe_smoothquant/README.md index c10a922607..f675f4bca9 100644 --- a/example/ck_tile/14_moe_smoothquant/README.md +++ b/example/ck_tile/14_moe_smoothquant/README.md @@ -9,7 +9,25 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper ``` # in the root of ck_tile mkdir build && cd build -../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... -make tile_example_moe_smoothquant -j +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_smoothquant -j`nproc` ``` This will result in an executable `build/bin/tile_example_moe_smoothquant` + +## example +``` +args: + -t tokens dimension (default:3328) + -h hidden_size dimension (default:4096) + -e experts (default:32) + -k topk (default:5) + -stride stride per row, if -1 then equal to hidden_size (default:-1) + -v cpu validation or not (default:1) + -kname print kernel name or not (default:1) + -prec_i input precision, fp16/bf16 (default:fp16) + -prec_o precision, int8/fp8 (default:int8) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:moe_smoothquant.json) +``` \ No newline at end of file diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp index 848fb87dcf..864ff31798 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -1,5 +1,6 @@ #include "ck_tile/host.hpp" #include "moe_smoothquant.hpp" +#include "json_dump.hpp" #include #include @@ -66,7 +67,9 @@ auto create_args(int argc, char* argv[]) .insert("prec_i", "fp16", "input precision, fp16/bf16") .insert("prec_o", "int8", "precision, int8/fp8") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "moe_smoothquant.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -244,6 +247,21 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } + if(arg_parser.get_int("json")) + { + dump_moe_smoothquant_json(arg_parser.get_str("jsonfile"), + prec_i, + prec_o, + tokens, + hidden_size, + stride, + experts, + topk, + pass, + ave_time, + 0, + gb_per_sec); + } return pass; } diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md index 089e1de78e..1376149177 100644 --- a/example/ck_tile/15_fused_moe/README.md +++ b/example/ck_tile/15_fused_moe/README.md @@ -69,4 +69,42 @@ summary of the key design of this fused-moe operator: // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) +``` + +## example +``` +args: + -t number of input tokens. (default:128) + If "local_t" presents, this value indicates global concurrency of all ranks. + -local_t Number of local input tokens for curent rank. (default:-1) + This value must be within range "[0, t)", or "-1"(no such feature) + This feature is to simulate EP case where where each rank has different tokens. + Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph. + -e num of experts (default:32) + -k topk (default:5) + -h hidden_size of this model (default:8192) + -i intermediate_size between 2 gemms of FFN (default:8192) + -stride stride per row, if -1 then equal to hidden_size (default:-1) + -bm blocking factor for sorted tokens (default:32) + -tp tensor parallel size (default:8) + -v cpu validation or not (default:1) + -kname print kernel name or not (default:1) + -prec_i input precision (default:bf16) + -prec_w weight precision (default:bf16) + -prec_o output precision (default:bf16) + -prec_st token scale data type. auto will set to fp32 (default:auto) + -prec_sw weight scale data type. auto will set to fp32 (default:auto) + -prec_sq (dynamic) smooth quant data type. auto will set to fp32 (default:auto) + -prec_kw topk-weight data type. auto will set to fp32 (default:auto) + -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) + -gate_only w0(gate/up) style, 0:gate+up will double interm size, 1:only gate (default:1) + -api benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0) + -act activation after first gemm. 0:gelu, 1:silu (default:0) + -balance if set to 1, will try balance the expert in topk-ids(convenient for testing) (default:0) + -init init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand normalized[0, 1]normalized(slow) (default:1) + -seed seed used to do random (default:11939) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:fused_moe.json) ``` \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index e4d87e5fef..3c459c6b95 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -5,6 +5,7 @@ #include #include "ck_tile/host.hpp" +#include "json_dump.hpp" #include "fused_moe.hpp" // different threshold for different dtype @@ -130,7 +131,9 @@ auto create_args(int argc, char* argv[]) "normalized(slow)") .insert("seed", "11939", "seed used to do random") .insert("warmup", "5", "cold iter") - .insert("repeat", "20", "hot iter"); + .insert("repeat", "20", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "fused_moe.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -513,6 +516,29 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; } std::cout << std::flush << std::endl; + + if(arg_parser.get_int("json") == 1) + { + dump_fused_moe_json(arg_parser.get_str("jsonfile"), + api_str, + prec_str, + tokens, + is_local_token, + local_tokens, + experts, + topk, + hidden_size, + intermediate_size, + stride, + block_m, + activation, + gate_only, + fused_quant, + pass, + ave_time, + cal_tflops(ave_time), + cal_tbps(ave_time)); + } return pass; } else if(api == 1) @@ -619,6 +645,29 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::cout << std::flush << std::endl; + if(arg_parser.get_int("json") == 1) + { + dump_fused_moe_json(arg_parser.get_str("jsonfile"), + api_str, + prec_str, + tokens, + is_local_token, + local_tokens, + experts, + topk, + hidden_size, + intermediate_size, + stride, + block_m, + activation, + gate_only, + fused_quant, + pass, + ave_time, + cal_tflops(ave_time), + cal_tbps(ave_time)); + } + return pass; } return false; diff --git a/example/ck_tile/16_batched_gemm/README.md b/example/ck_tile/16_batched_gemm/README.md index 8a64a3912c..d82f20eb2b 100644 --- a/example/ck_tile/16_batched_gemm/README.md +++ b/example/ck_tile/16_batched_gemm/README.md @@ -15,23 +15,25 @@ This will result in an executable `build/bin/tile_example_batched_gemm` ## example ``` args: - -m m dimension (default:256) - -n n dimension (default:128) - -k k dimension (default:128) - -a_layout A tensor data layout (default:R) (R for Row, C for Col) - -b_layout B tensor data layout (default:R) (R for Row, C for Col) - -c_layout C tensor data layout (default:R) (R for Row, C for Col) - -stride_a Tensor A stride (default:128) - -stride_b Tensor B stride (default:128) - -stride_c Tensor C stride (default:128) - -batch_stride_a Batch A stride (default:32768) - -batch_stride_b Batch B stride (default:16384) - -batch_stride_c Batch C stride (default:32768) - -batch_count Batch count (default:16) - -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -e Absolute error tolerance (default:1e-5) - -prec data type. fp16/bf16/fp8/bf8 (default:fp16) - -warmup number of iterations before benchmark the kernel (default:10) - -repeat number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -m m dimension (default:512) + -n n dimension (default:1024) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -a_layout A tensor data layout - Row by default (default:R) + -b_layout B tensor data layout - Row by default (default:C) + -c_layout C tensor data layout - Row by default (default:R) + -batch_stride_a Batch A stride (default:1048576) + -batch_stride_b Batch B stride (default:2097152) + -batch_stride_c Batch C stride (default:524288) + -batch_count Batch count (default:8) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:50) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k splitK value (default:1) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:cktile_batched_gemm.json) ``` \ No newline at end of file diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 78d915e873..c11453b0d1 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -9,6 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 @@ -75,7 +76,9 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); + .insert("split_k", "1", "splitK value") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "cktile_batched_gemm.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 6d26cfe675..3289a2836b 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -2,7 +2,6 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -77,21 +76,6 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, CDEElementWise>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::string op_name{"Batched Gemm"}; - std::size_t flop = std::size_t(2) * batch_count * M * N * K; - std::size_t num_byte = sizeof(ADataType) * batch_count * M * K + - sizeof(BDataType) * batch_count * N * K + - sizeof(CDataType) * batch_count * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B - << " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : " - << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; - return ave_time; } @@ -186,31 +170,47 @@ int run_batched_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_batched_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_count, - kbatch, - n_warmup, - n_repeat); + float ave_time = invoke_batched_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::string op_name{"Batched Gemm"}; + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_byte = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * N * K + + sizeof(CDataType) * batch_count * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B + << " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : " + << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + bool pass = true; if(arg_parser.get_int("v") == 1) @@ -310,6 +310,27 @@ int run_batched_gemm_example_with_layouts(int argc, std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_batched_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + pass, + ave_time, + tflops, + gb_per_sec, + "batched_gemm"); + } + return pass; } diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 8715ee79e1..85a02c2231 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -157,17 +157,20 @@ This will result in an executable `build/bin/tile_example_grouped_gemm` ## example ``` args: - -Ms M dimensions - (Default: empty). - -Ns N dimensions - (Default: empty). - -Ks K dimensions - (Default: empty). - -stride_As Tensor A strides - (Default: empty). - -stride_Bs Tensor B strides - (Default: empty). - -stride_Cs Tensor C strides - (Default: empty). - -a_layout A tensor data layout - (Default: Row). - -b_layout B tensor data layout - (Default: Col). - -c_layout C tensor data layout - (Default: Row). - -validate 0. No validation, 1. Validation on CPU. (Default: 1). - -warmup Number of iterations before benchmark the kernel. (Default: 10). - -repeat Number of iterations to benchmark the kernel. (Default: 100). - -group_count Group count. (Default: 16). + -Ms M dimensions - empty by default. (default:) + -Ns N dimensions - empty by default. (default:) + -Ks K dimensions - empty by default. (default:) + -stride_As Tensor A strides - it is empty by default. (default:) + -stride_Bs Tensor B strides - it is empty by default. (default:) + -stride_Cs Tensor C strides - it is empty by default. (default:) + -a_layout A tensor data layout - Row by default. (default:R) + -b_layout B tensor data layout - Row by default. (default:C) + -c_layout C tensor data layout - Row by default. (default:R) + -validate 0. No validation, 1. Validation on CPU. (default:1) + -warmup number of iterations before benchmark the kernel. (default:10) + -repeat number of iterations to benchmark the kernel. (default:100) + -group_count group count. (default:8) + -kbatch kbatch for SplitK (default:1) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:grouped_gemm.json) ``` diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index e992cb3118..7a8b670123 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,6 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "json_dump.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 @@ -171,7 +172,9 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") .insert("group_count", "8", "group count.") - .insert("kbatch", "1", "kbatch for SplitK"); + .insert("kbatch", "1", "kbatch for SplitK") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "grouped_gemm.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 425299203f..2e1afc3533 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -2,7 +2,6 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - template static constexpr inline auto is_row_major(Layout layout_) { @@ -114,24 +113,6 @@ float invoke_gemm(int n_warmup, CDataType>(stream, group_count, kargs_ptr, splitk); } - std::string op_name{"Grouped Gemm"}; - - std::size_t flop = 0, num_btype = 0; - for(int j = 0; j < group_count; ++j) - { - flop += std::size_t(2) * args[j].M * args[j].N * args[j].K; - - num_btype += sizeof(ADataType) * args[j].M * args[j].K + - sizeof(BDataType) * args[j].K * args[j].N + - sizeof(CDataType) * args[j].M * args[j].N; - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - return ave_time; } @@ -259,17 +240,34 @@ int run_grouped_gemm_example_with_layouts(int argc, {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } - invoke_gemm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout, - Persistent>(warmup, repeat, group_count, gemm_descs); + float ave_time = invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + Persistent>(warmup, repeat, group_count, gemm_descs); + + std::string op_name{"Grouped Gemm"}; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; for(int i = 0; i < group_count; i++) { @@ -304,6 +302,17 @@ int run_grouped_gemm_example_with_layouts(int argc, std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + return pass; } diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md index eeaa7658bd..c58700fc7b 100644 --- a/example/ck_tile/18_flatmm/README.md +++ b/example/ck_tile/18_flatmm/README.md @@ -16,20 +16,23 @@ This will result in an executable `build/bin/tile_example_flatmm_basic` ## example ``` args: - -b batch size (default:1) - -m m dimension (default:1024) - -n n dimension (default:2048) - -k k dimension (default:64) - -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) - -c_layout Tensor C data layout (default: R) + -m m dimension (default:256) + -n n dimension (default:256) + -k k dimension (default:128) + -a_layout A tensor data layout - Row by default (default:R) + -b_layout B tensor data layout - Row by default (default:C) + -c_layout C tensor data layout - Row by default (default:R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) - -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -e Absolute error tolerance (default:1e-5) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) -prec data type. fp16/bf16/fp8/bf8 (default:fp16) - -warmup number of iterations before benchmark the kernel (default:10) + -warmup number of iterations before benchmark the kernel (default:50) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k splitK value (default:1) + -init 0:random, 1:linear, 2:constant(1) (default:0) + -warp_tile 0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only) (default:0) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:flatmm_basic.json) ``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 963a6ba675..64e141860e 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -183,9 +183,10 @@ auto create_args(int argc, char* argv[]) .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("warp_tile", - "0", - "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + .insert( + "warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "flatmm_basic.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index ff1a239cba..e526ddc3f5 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -2,7 +2,7 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include - +#include "json_dump.hpp" template constexpr const char* DataTypeToString() { @@ -140,17 +140,6 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, CDEElementWise>( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() - << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A - << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time - << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - return ave_time; } @@ -242,27 +231,38 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); - invoke_flatmm, - AccDataType, - CDataType, - ALayout, - BLayout, - ck_tile::tuple<>, - CLayout>(a_dev_buf, - b_shuffle_dev_buf, - c_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + float ave_time = invoke_flatmm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; c_dev_buf.FromDevice(c_rslt_host.data()); bool pass = true; @@ -350,5 +350,22 @@ int run_flatmm_example_with_layouts(int argc, std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; } + if(arg_parser.get_int("json") == 1) + { + dump_flatmm_json_results(arg_parser.get_str("jsonfile"), + DataTypeToString(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + pass, + ave_time, + tflops, + gb_per_sec); + } + return pass; } diff --git a/example/ck_tile/19_gemm_multi_d/README.md b/example/ck_tile/19_gemm_multi_d/README.md index 2cf2b1ea03..b9416f3112 100644 --- a/example/ck_tile/19_gemm_multi_d/README.md +++ b/example/ck_tile/19_gemm_multi_d/README.md @@ -17,19 +17,21 @@ This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16` ## example ``` args: - -m M dimensions - (Default: 3840) - -n N dimensions - (Default: 4096) - -k K dimensions - (Default: 4096) --a_layout Tensor A layout (default:R) --b_layout Tensor B layout (default:C) --ds_layout Tensor D layout (default:R) --e_layout Tensor E layout (default:R) --stride_a Tensor A strides - (Default: 0) --stride_b Tensor B strides - (Default: 0) --stride_e Tensor C strides - (Default: 0) --stride_ds Tensor D strides - (Default: 0) --validate 0. No validation, 1. Validation on GPU. (Default: 1) - -warmup Number of iterations before benchmark the kernel. (Default: 10) - -repeat Number of iterations to benchmark the kernel. (Default: 100) - -kbatch kbatch for SplitK. (Default 1) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:4096) + -a_layout A tensor data layout - Row by default (default:R) + -b_layout B tensor data layout - Col by default (default:C) + -ds_layout Ds tensor data layout - Row by default (default:R) + -e_layout E tensor data layout - Row by default (default:R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_ds Tensor Ds stride (default:0) + -stride_e Tensor E stride (default:0) + -v 0. No validation, 1. Validation on GPU (default:1) + -warmup number of iterations before benchmark the kernel (default:50) + -repeat number of iterations to benchmark the kernel (default:100) + -kbatch kbatch for SplitK (default:1) + -json 0: No Json, 1: Dump Results in Json format (default:0) + -jsonfile json file name to dump results (default:cktile_gemm_multi_d_fp16.json) ``` diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index 87b9592553..d28f823eda 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -58,7 +58,9 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "0. No validation, 1. Validation on GPU") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("kbatch", "1", "kbatch for SplitK"); + .insert("kbatch", "1", "kbatch for SplitK") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "cktile_gemm_multi_d_fp16.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc index a0d7157d03..2b388768b0 100644 --- a/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc +++ b/example/ck_tile/19_gemm_multi_d/run_gemm_multi_d_fp16_example.inc @@ -3,6 +3,7 @@ #pragma once #include +#include "json_dump.hpp" template ( gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::string op_name{"Gemm Multiple-D"}; - static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); - - std::size_t flop = 0, num_btype = 0; - - flop += std::size_t(2) * M * N * K; - - ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { - num_btype += sizeof(ck_tile::remove_cvref_t>) * M * N; - flop += sizeof(ck_tile::remove_cvref_t>) * M * N; - }); - - num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Run Gemm Multiple-D kernel with:\n"; - std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; - std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE - << "\n"; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << "\n"; - return ave_time; } @@ -159,29 +136,53 @@ int run_multiple_d_gemm_example_with_layouts(int argc, std::array stridesDs = {StrideD0, StrideD1}; - invoke_gemm_multi_d(a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - ds_ptr_buf, - e_m_n_dev_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - stridesDs, - StrideE, - n_warmup, - n_repeat, - k_batch); + float ave_time = invoke_gemm_multi_d(a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + std::string op_name{"Gemm Multiple-D"}; + static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * M * N; + flop += sizeof(ck_tile::remove_cvref_t>) * M * N; + }); + + num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm Multiple-D kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE + << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); @@ -217,6 +218,24 @@ int run_multiple_d_gemm_example_with_layouts(int argc, << std::endl; std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl; } + + if(arg_parser.get_int("json") == 1) + { + dump_gemm_multi_d_fp16_json_results(arg_parser.get_str("jsonfile"), + op_name, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE, + pass, + ave_time, + tflops, + gb_per_sec); + } return pass; } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index f3a7a60fd9..2bc33b9b02 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -120,7 +120,8 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)"); + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 2cc539e117..879f3db141 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -5,6 +5,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" +#include "json_dump.hpp" auto create_args(int argc, char* argv[]) { @@ -15,7 +16,9 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "10", "cold iter") - .insert("repeat", "50", "hot iter"); + .insert("repeat", "50", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "elementwise.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -195,6 +198,18 @@ bool run(const ck_tile::ArgParser& arg_parser) y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); } + if(arg_parser.get_int("json") == 1) + { + dump_elementwise_json_results(arg_parser.get_str("jsonfile"), + arg_parser.get_str("prec"), + kGridSize, + kBlockSize, + ave_time, + 0, + 0, + "elementwise_add"); + } + return pass; } diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index 7087d092a2..82dc59c2fb 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -5,6 +5,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" +#include "json_dump.hpp" auto create_args(int argc, char* argv[]) { @@ -16,7 +17,9 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "10", "cold iter") - .insert("repeat", "50", "hot iter"); + .insert("repeat", "50", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "elementwise_add_4d.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -140,6 +143,18 @@ bool run(const ck_tile::ArgParser& arg_parser) y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); } + if(arg_parser.get_int("json") == 1) + { + dump_elementwise_json_results(arg_parser.get_str("jsonfile"), + arg_parser.get_str("prec"), + kGridSize, + kBlockSize, + ave_time, + 0, + 0, + "elementwise_add_4d"); + } + return pass; } diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index 28cdaf27b9..a1825927c0 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -4,6 +4,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/host/reference/reference_transpose.hpp" +#include "json_dump.hpp" auto create_args(int argc, char* argv[]) { @@ -14,7 +15,9 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "10", "cold iter") - .insert("repeat", "50", "hot iter"); + .insert("repeat", "50", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "elementwise_transpose.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -137,6 +140,18 @@ bool run(const ck_tile::ArgParser& arg_parser) y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01); } + if(arg_parser.get_int("json") == 1) + { + dump_elementwise_json_results(arg_parser.get_str("jsonfile"), + arg_parser.get_str("prec"), + kGridSize, + kBlockSize, + ave_time, + 0, + 0, + "elementwise_transpose"); + } + return pass; } diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index 782d3da24d..4fd36980ef 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -5,6 +5,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" +#include "json_dump.hpp" auto create_args(int argc, char* argv[]) { @@ -15,7 +16,9 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("prec", "fp16", "precision") .insert("warmup", "10", "cold iter") - .insert("repeat", "50", "hot iter"); + .insert("repeat", "50", "hot iter") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "elementwise_unary.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -127,6 +130,18 @@ bool run(const ck_tile::ArgParser& arg_parser) y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); } + if(arg_parser.get_int("json") == 1) + { + dump_elementwise_json_results(arg_parser.get_str("jsonfile"), + arg_parser.get_str("prec"), + kGridSize, + kBlockSize, + ave_time, + 0, + 0, + "elementwise_unary"); + } + return pass; } diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp index 571386694b..74441a149a 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -12,6 +12,7 @@ #include "batched_transpose_example.hpp" +#include "json_dump.hpp" #if 0 template void dump_host_tensor_4d(const ck_tile::HostTensor& x) @@ -103,6 +104,8 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "t to 1 will print kernel name") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "batched_transpose.json", "json file name to dump results") .insert("pipeline", "0", "0: no LDS usage, 1: LDS-accelerated (gfx950)"); bool result = arg_parser.parse(argc, argv); @@ -236,6 +239,23 @@ bool run_batched_transpose(ck_tile::ArgParser args) "--------------------------------------------------------------------\n", rtn ? "y" : "n"); fflush(stdout); + + if(args.get_int("json") == 1) + { + dump_batched_transpose_json(args.get_str("jsonfile"), + N, + C, + H, + W, + layout_in, + layout_out, + prec, + ms, + 0, + gb_per_sec, + rtn); + } + return rtn; } diff --git a/example/include/json_dump.hpp b/example/include/json_dump.hpp new file mode 100644 index 0000000000..05d6a66024 --- /dev/null +++ b/example/include/json_dump.hpp @@ -0,0 +1,700 @@ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant" +#include "rapidjson/writer.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/document.h" +#include "rapidjson/rapidjson.h" +// #include +#pragma GCC diagnostic pop + +#define START_JSON_DUMP_FILE(file_name) \ + std::string file_str(file_name); \ + std::ofstream file(file_str); \ + if(!file.is_open()) \ + { \ + throw std::runtime_error("Could not open file: " + std::string(file_name)); \ + } \ + rapidjson::StringBuffer s; \ + rapidjson::Writer writer(s); \ + writer.StartObject(); + +#define END_JSON_DUMP_FILE() \ + writer.EndObject(); \ + file << s.GetString(); \ + file.close(); \ + std::cout << "Results written to " << file_str << " successfully" << std::endl; + +#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value); +#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes); + +template +void add_key_value_pair(rapidjson::Writer& writer, + const char* key, + T value) +{ + writer.Key(key); + if constexpr(std::is_same::value) + { + writer.String(value, static_cast(std::strlen(value))); + } + else if constexpr(std::is_same::value) + { + writer.String(value.c_str(), static_cast(value.length())); + } + else if constexpr(std::is_floating_point::value) + { + writer.Double(static_cast(value)); + } + else if constexpr(std::is_integral::value) + { + writer.Int64(static_cast(value)); + } + else + { + static_assert(std::is_same::value || std::is_floating_point::value || + std::is_integral::value, + "Unsupported type for JSON serialization"); + } +} + +static void add_perf_to_json(rapidjson::Writer& writer, + float time, + float tflops, + float gbytes) +{ + std::string roster("perf"); + writer.String(roster.c_str(), static_cast(roster.length())); + + writer.StartArray(); + writer.StartObject(); + + add_key_value_pair(writer, "time", time); + add_key_value_pair(writer, "tflops", tflops); + add_key_value_pair(writer, "gbytes", gbytes); + + writer.EndObject(); + writer.EndArray(); +} + +// Helper traits to check for static member existence +template +struct has_warp_tile_members : std::false_type +{ +}; + +template +struct has_warp_tile_members< + T, + std::void_t> + : std::true_type +{ +}; + +template + typename DTypeTraits> +void dump_gemm_json_results(const std::string& json_filename, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + bool persistent, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "gemm_basic") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("M", M); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("K", K); + ADD_KEY_VALUE("stride_A", stride_A); + ADD_KEY_VALUE("stride_B", stride_B); + ADD_KEY_VALUE("stride_C", stride_C); + ADD_KEY_VALUE("A_layout", ALayout::name); + ADD_KEY_VALUE("B_layout", BLayout::name); + ADD_KEY_VALUE("C_layout", CLayout::name); + using TraitsADataType = DTypeTraits; + using TraitsBDataType = DTypeTraits; + using TraitsCDataType = DTypeTraits; + ADD_KEY_VALUE("A_type", TraitsADataType::name); + ADD_KEY_VALUE("B_type", TraitsBDataType::name); + ADD_KEY_VALUE("C_type", TraitsCDataType::name); + ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off"); + + if constexpr(has_warp_tile_members::value) + { + ADD_KEY_VALUE("warp_tile", + std::to_string(GemmConfig::M_Warp_Tile) + "x" + + std::to_string(GemmConfig::N_Warp_Tile) + "x" + + std::to_string(GemmConfig::K_Warp_Tile)); + } + ADD_KEY_VALUE("persistent", persistent ? "on" : "off"); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec); + END_JSON_DUMP_FILE(); +} + +void dump_batched_gemm_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int batch_stride_A, + int batch_stride_B, + int batch_stride_C, + int batch_count, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "batched_gemm_basic") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("op_name", op_name); + ADD_KEY_VALUE("M", M); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("K", K); + ADD_KEY_VALUE("stride_A", stride_A); + ADD_KEY_VALUE("stride_B", stride_B); + ADD_KEY_VALUE("stride_C", stride_C); + ADD_KEY_VALUE("batch_stride_A", batch_stride_A); + ADD_KEY_VALUE("batch_stride_B", batch_stride_B); + ADD_KEY_VALUE("batch_stride_C", batch_stride_C); + ADD_KEY_VALUE("batch_count", batch_count); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +template +void dump_grouped_gemm_json_results(const std::string& json_filename, + const std::string& op_name, + int group_count, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "grouped_gemm") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("op_name", op_name); + ADD_KEY_VALUE("group_count", group_count); + ADD_KEY_VALUE("A_layout", ALayout::name); + ADD_KEY_VALUE("B_layout", BLayout::name); + ADD_KEY_VALUE("C_layout", CLayout::name); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_flatmm_json_results(const std::string& json_filename, + const std::string& datatype, + int M, + int N, + int K, + int stride_A, + int stride_B, + int stride_C, + int kbatch, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "flatmm_basic") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("DataType", datatype); + ADD_KEY_VALUE("M", M); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("K", K); + ADD_KEY_VALUE("StrideA", stride_A); + ADD_KEY_VALUE("StrideB", stride_B); + ADD_KEY_VALUE("StrideC", stride_C); + ADD_KEY_VALUE("kbatch", kbatch); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename, + const std::string& op_name, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "gemm_multi_d_fp16") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("op_name", op_name); + ADD_KEY_VALUE("M", M); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("K", K); + ADD_KEY_VALUE("StrideA", StrideA); + ADD_KEY_VALUE("StrideB", StrideB); + ADD_KEY_VALUE("StrideD0", StrideD0); + ADD_KEY_VALUE("StrideD1", StrideD1); + ADD_KEY_VALUE("StrideE", StrideE); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_elementwise_json_results(const std::string& json_filename, + const std::string& prec, + int grid_size, + int block_size, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "elementwise") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec", prec); + ADD_KEY_VALUE("grid_size", grid_size); + ADD_KEY_VALUE("block_size", block_size); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_layernorm2d_fwd_json_results(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + const std::string& prec_sm, + const std::string& prec_sy, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "layernorm2d_fwd") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec_i", prec_i); + ADD_KEY_VALUE("prec_o", prec_o); + ADD_KEY_VALUE("prec_sm", prec_sm); + ADD_KEY_VALUE("prec_sy", prec_sy); + ADD_KEY_VALUE("m", m); + ADD_KEY_VALUE("n", n); + ADD_KEY_VALUE("x_stride", x_stride); + ADD_KEY_VALUE("xr_stride", xr_stride); + ADD_KEY_VALUE("y_stride", y_stride); + ADD_KEY_VALUE("yr_stride", yr_stride); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +template typename DTypeTraits> +void dump_reduce_json_results(const std::string& json_filename, + int N, + int C, + int H, + int W, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "reduce") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + using Traits = DTypeTraits; + ADD_KEY_VALUE("data_type", Traits::name); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("C", C); + ADD_KEY_VALUE("H", H); + ADD_KEY_VALUE("W", W); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_permute_json_results(const std::string& json_filename, + const std::string& data_type, + bool pass, + float ave_time, + float tflop, + float gb_per_sec, + const std::string& kernel_name = "permute") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("data_type", data_type); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_topk_softmax_json(const std::string& json_filename, + const std::string& input_prec, + const std::string& weight_prec, + int tokens, + int experts, + int topk, + int stride_input, + int stride_output, + float ave_time, + float tflop, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "topk_softmax") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("input_prec", input_prec); + ADD_KEY_VALUE("weight_prec", weight_prec); + ADD_KEY_VALUE("tokens", tokens); + ADD_KEY_VALUE("experts", experts); + ADD_KEY_VALUE("topk", topk); + ADD_KEY_VALUE("stride_input", stride_input); + ADD_KEY_VALUE("stride_output", stride_output); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec); + END_JSON_DUMP_FILE(); +} + +void dump_rmsnorm2d_fwd_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int xr_stride, + int y_stride, + int yr_stride, + int use_model_sensitive_rmsnorm, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "rmsnorm2d_fwd") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec", prec_str); + ADD_KEY_VALUE("m", m); + ADD_KEY_VALUE("n", n); + ADD_KEY_VALUE("x_stride", x_stride); + ADD_KEY_VALUE("xr_stride", xr_stride); + ADD_KEY_VALUE("y_stride", y_stride); + ADD_KEY_VALUE("yr_stride", yr_stride); + ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec); + END_JSON_DUMP_FILE(); +} + +void dump_add_rmsnorm2d_rdquant_fwd_json( + const std::string& json_filename, + const std::string& input_data_type, + const std::string& quantized_data_type, + int m, + int n, + int stride, + float epsilon, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("input_data_type", input_data_type); + ADD_KEY_VALUE("quantized_data_type", quantized_data_type); + ADD_KEY_VALUE("m", m); + ADD_KEY_VALUE("n", n); + ADD_KEY_VALUE("stride", stride); + ADD_KEY_VALUE("epsilon", epsilon); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec); + END_JSON_DUMP_FILE(); +} + +void dump_smoothquant_json(const std::string& json_filename, + const std::string& prec_str, + int m, + int n, + int x_stride, + int y_stride, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "smoothquant") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec", prec_str); + ADD_KEY_VALUE("m", m); + ADD_KEY_VALUE("n", n); + ADD_KEY_VALUE("x_stride", x_stride); + ADD_KEY_VALUE("y_stride", y_stride); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec); + END_JSON_DUMP_FILE(); +} + +void dump_moe_sorting_json(const std::string& json_filename, + const std::string& index_prec, + const std::string& weight_prec, + const std::string& workspace_size, + int dispatch_policy, + int tokens, + int num_experts, + int topk, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "moe_sorting") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("index_prec", index_prec); + ADD_KEY_VALUE("weight_prec", weight_prec); + ADD_KEY_VALUE("workspace_size", workspace_size); + ADD_KEY_VALUE("dispatch_policy", dispatch_policy); + ADD_KEY_VALUE("tokens", tokens); + ADD_KEY_VALUE("num_experts", num_experts); + ADD_KEY_VALUE("topk", topk); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_batched_transpose_json(const std::string& json_filename, + int N, + int C, + int H, + int W, + const std::string& layout_in, + const std::string& layout_out, + const std::string& prec, + float ave_time, + float tflops, + float gb_per_sec, + bool pass, + const std::string& kernel_name = "batched_transpose") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("N", N); + ADD_KEY_VALUE("C", C); + ADD_KEY_VALUE("H", H); + ADD_KEY_VALUE("W", W); + ADD_KEY_VALUE("LayoutIn", layout_in); + ADD_KEY_VALUE("LayoutOut", layout_out); + ADD_KEY_VALUE("Precision", prec); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_moe_smoothquant_json(const std::string& json_filename, + const std::string& prec_i, + const std::string& prec_o, + int tokens, + int hidden_size, + int stride, + int experts, + int topk, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "moe_smoothquant") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec_i", prec_i); + ADD_KEY_VALUE("prec_o", prec_o); + ADD_KEY_VALUE("tokens", tokens); + ADD_KEY_VALUE("hidden_size", hidden_size); + ADD_KEY_VALUE("stride", stride); + ADD_KEY_VALUE("experts", experts); + ADD_KEY_VALUE("topk", topk); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_fused_moe_json(const std::string& json_filename, + const std::string& api_str, + const std::string& prec_str, + int tokens, + bool is_local_token, + int local_tokens, + int experts, + int topk, + int hidden_size, + int intermediate_size, + int stride, + int block_m, + int activation, + bool gate_only, + bool fused_quant, + bool pass, + float ave_time, + float tflops, + float tb_per_sec, + const std::string& kernel_name = "fused_moe") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("api", api_str); + ADD_KEY_VALUE("prec", prec_str); + ADD_KEY_VALUE("tokens", tokens); + if(is_local_token) + { + ADD_KEY_VALUE("local_tokens", local_tokens); + } + ADD_KEY_VALUE("experts", experts); + ADD_KEY_VALUE("topk", topk); + ADD_KEY_VALUE("hidden_size", hidden_size); + ADD_KEY_VALUE("intermediate_size", intermediate_size); + ADD_KEY_VALUE("stride", stride); + ADD_KEY_VALUE("block_m", block_m); + ADD_KEY_VALUE("activation", activation); + ADD_KEY_VALUE("gate_only", gate_only); + ADD_KEY_VALUE("fused_quant", fused_quant); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f)) + END_JSON_DUMP_FILE(); +} + +void dump_fmha_fwd_json_results(const std::string& json_filename, + const std::string& prec, + const std::string& mode, + const std::string& io_layout, + int batch, + int nhead, + int nhead_k, + int seqlen_qs, + int seqlen_ks, + int seqlen_kpads, + int hdim_q, + int hdim_v, + float scale_s, + float p_drop, + bool lse, + bool squant, + const std::string& bais, + const std::string& vlayout, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_fwd") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec", prec); + ADD_KEY_VALUE("mode", mode); + ADD_KEY_VALUE("io_layout", io_layout); + ADD_KEY_VALUE("batch", batch); + ADD_KEY_VALUE("nhead", nhead); + ADD_KEY_VALUE("nhead_k", nhead_k); + ADD_KEY_VALUE("seqlen_q", seqlen_qs); + ADD_KEY_VALUE("seqlen_k", seqlen_ks); + ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads); + ADD_KEY_VALUE("hdim_q", hdim_q); + ADD_KEY_VALUE("hdim_v", hdim_v); + ADD_KEY_VALUE("scale_s", scale_s); + ADD_KEY_VALUE("p_drop", p_drop); + ADD_KEY_VALUE("lse", lse); + ADD_KEY_VALUE("squant", squant); + ADD_KEY_VALUE("bias", bais); + ADD_KEY_VALUE("vlayout", vlayout); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} + +void dump_fmha_bwd_json_results(const std::string& json_filename, + const std::string& data_type, + const std::string& mode, + const std::string& i_perm, + const std::string& o_perm, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + float scale, + const std::string& bias, + bool use_dbias, + float p_drop, + bool s_randval, + bool deterministic, + const std::string& mask, + int mask_left, + int mask_right, + int workspace_size, + bool pass, + float ave_time, + float tflops, + float gb_per_sec, + const std::string& kernel_name = "fmha_bwd") +{ + START_JSON_DUMP_FILE(json_filename); + ADD_KEY_VALUE("name", kernel_name); + ADD_KEY_VALUE("prec", data_type); + ADD_KEY_VALUE("mode", mode); + ADD_KEY_VALUE("i_perm", i_perm); + ADD_KEY_VALUE("o_perm", o_perm); + ADD_KEY_VALUE("batch", batch); + ADD_KEY_VALUE("nhead", nhead); + ADD_KEY_VALUE("nhead_k", nhead_k); + ADD_KEY_VALUE("seqlen_q", seqlen_q); + ADD_KEY_VALUE("seqlen_k", seqlen_k); + ADD_KEY_VALUE("hdim_q", hdim_q); + ADD_KEY_VALUE("hdim_v", hdim_v); + ADD_KEY_VALUE("scale", scale); + ADD_KEY_VALUE("bias", bias); + ADD_KEY_VALUE("use_dbias", use_dbias); + ADD_KEY_VALUE("p_drop", p_drop); + ADD_KEY_VALUE("s_randval", s_randval); + ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false"); + ADD_KEY_VALUE("mask", mask); + ADD_KEY_VALUE("mask_left", mask_left); + ADD_KEY_VALUE("mask_right", mask_right); + ADD_KEY_VALUE("workspace_size", workspace_size); + ADD_KEY_VALUE("verification", pass ? "pass" : "fail"); + ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) + END_JSON_DUMP_FILE(); +} diff --git a/include/rapidjson/allocators.h b/include/rapidjson/allocators.h new file mode 100644 index 0000000000..275417bd8b --- /dev/null +++ b/include/rapidjson/allocators.h @@ -0,0 +1,693 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ALLOCATORS_H_ +#define RAPIDJSON_ALLOCATORS_H_ + +#include "rapidjson.h" +#include "internal/meta.h" + +#include +#include + +#if RAPIDJSON_HAS_CXX11 +#include +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// Allocator + +/*! \class rapidjson::Allocator + \brief Concept for allocating, resizing and freeing memory block. + + Note that Malloc() and Realloc() are non-static but Free() is static. + + So if an allocator need to support Free(), it needs to put its pointer in + the header of memory block. + +\code +concept Allocator { + static const bool kNeedFree; //!< Whether this allocator needs to call Free(). + + // Allocate a memory block. + // \param size of the memory block in bytes. + // \returns pointer to the memory block. + void* Malloc(size_t size); + + // Resize a memory block. + // \param originalPtr The pointer to current memory block. Null pointer is permitted. + // \param originalSize The current size in bytes. (Design issue: since some allocator may not book-keep this, explicitly pass to it can save memory.) + // \param newSize the new size in bytes. + void* Realloc(void* originalPtr, size_t originalSize, size_t newSize); + + // Free a memory block. + // \param pointer to the memory block. Null pointer is permitted. + static void Free(void *ptr); +}; +\endcode +*/ + + +/*! \def RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY + \ingroup RAPIDJSON_CONFIG + \brief User-defined kDefaultChunkCapacity definition. + + User can define this as any \c size that is a power of 2. +*/ + +#ifndef RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY +#define RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY (64 * 1024) +#endif + + +/////////////////////////////////////////////////////////////////////////////// +// CrtAllocator + +//! C-runtime library allocator. +/*! This class is just wrapper for standard C library memory routines. + \note implements Allocator concept +*/ +class CrtAllocator { +public: + static const bool kNeedFree = true; + void* Malloc(size_t size) { + if (size) // behavior of malloc(0) is implementation defined. + return RAPIDJSON_MALLOC(size); + else + return NULL; // standardize to returning NULL. + } + void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) { + (void)originalSize; + if (newSize == 0) { + RAPIDJSON_FREE(originalPtr); + return NULL; + } + return RAPIDJSON_REALLOC(originalPtr, newSize); + } + static void Free(void *ptr) RAPIDJSON_NOEXCEPT { RAPIDJSON_FREE(ptr); } + + bool operator==(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { + return true; + } + bool operator!=(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { + return false; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// MemoryPoolAllocator + +//! Default memory allocator used by the parser and DOM. +/*! This allocator allocate memory blocks from pre-allocated memory chunks. + + It does not free memory blocks. And Realloc() only allocate new memory. + + The memory chunks are allocated by BaseAllocator, which is CrtAllocator by default. + + User may also supply a buffer as the first chunk. + + If the user-buffer is full then additional chunks are allocated by BaseAllocator. + + The user-buffer is not deallocated by this allocator. + + \tparam BaseAllocator the allocator type for allocating memory chunks. Default is CrtAllocator. + \note implements Allocator concept +*/ +template +class MemoryPoolAllocator { + //! Chunk header for perpending to each chunk. + /*! Chunks are stored as a singly linked list. + */ + struct ChunkHeader { + size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself). + size_t size; //!< Current size of allocated memory in bytes. + ChunkHeader *next; //!< Next chunk in the linked list. + }; + + struct SharedData { + ChunkHeader *chunkHead; //!< Head of the chunk linked-list. Only the head chunk serves allocation. + BaseAllocator* ownBaseAllocator; //!< base allocator created by this object. + size_t refcount; + bool ownBuffer; + }; + + static const size_t SIZEOF_SHARED_DATA = RAPIDJSON_ALIGN(sizeof(SharedData)); + static const size_t SIZEOF_CHUNK_HEADER = RAPIDJSON_ALIGN(sizeof(ChunkHeader)); + + static inline ChunkHeader *GetChunkHead(SharedData *shared) + { + return reinterpret_cast(reinterpret_cast(shared) + SIZEOF_SHARED_DATA); + } + static inline uint8_t *GetChunkBuffer(SharedData *shared) + { + return reinterpret_cast(shared->chunkHead) + SIZEOF_CHUNK_HEADER; + } + + static const size_t kDefaultChunkCapacity = RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY; //!< Default chunk capacity. + +public: + static const bool kNeedFree = false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator) + static const bool kRefCounted = true; //!< Tell users that this allocator is reference counted on copy + + //! Constructor with chunkSize. + /*! \param chunkSize The size of memory chunk. The default is kDefaultChunkSize. + \param baseAllocator The allocator for allocating memory chunks. + */ + explicit + MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) : + chunk_capacity_(chunkSize), + baseAllocator_(baseAllocator ? baseAllocator : RAPIDJSON_NEW(BaseAllocator)()), + shared_(static_cast(baseAllocator_ ? baseAllocator_->Malloc(SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER) : 0)) + { + RAPIDJSON_ASSERT(baseAllocator_ != 0); + RAPIDJSON_ASSERT(shared_ != 0); + if (baseAllocator) { + shared_->ownBaseAllocator = 0; + } + else { + shared_->ownBaseAllocator = baseAllocator_; + } + shared_->chunkHead = GetChunkHead(shared_); + shared_->chunkHead->capacity = 0; + shared_->chunkHead->size = 0; + shared_->chunkHead->next = 0; + shared_->ownBuffer = true; + shared_->refcount = 1; + } + + //! Constructor with user-supplied buffer. + /*! The user buffer will be used firstly. When it is full, memory pool allocates new chunk with chunk size. + + The user buffer will not be deallocated when this allocator is destructed. + + \param buffer User supplied buffer. + \param size Size of the buffer in bytes. It must at least larger than sizeof(ChunkHeader). + \param chunkSize The size of memory chunk. The default is kDefaultChunkSize. + \param baseAllocator The allocator for allocating memory chunks. + */ + MemoryPoolAllocator(void *buffer, size_t size, size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) : + chunk_capacity_(chunkSize), + baseAllocator_(baseAllocator), + shared_(static_cast(AlignBuffer(buffer, size))) + { + RAPIDJSON_ASSERT(size >= SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER); + shared_->chunkHead = GetChunkHead(shared_); + shared_->chunkHead->capacity = size - SIZEOF_SHARED_DATA - SIZEOF_CHUNK_HEADER; + shared_->chunkHead->size = 0; + shared_->chunkHead->next = 0; + shared_->ownBaseAllocator = 0; + shared_->ownBuffer = false; + shared_->refcount = 1; + } + + MemoryPoolAllocator(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT : + chunk_capacity_(rhs.chunk_capacity_), + baseAllocator_(rhs.baseAllocator_), + shared_(rhs.shared_) + { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + ++shared_->refcount; + } + MemoryPoolAllocator& operator=(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT + { + RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0); + ++rhs.shared_->refcount; + this->~MemoryPoolAllocator(); + baseAllocator_ = rhs.baseAllocator_; + chunk_capacity_ = rhs.chunk_capacity_; + shared_ = rhs.shared_; + return *this; + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + MemoryPoolAllocator(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT : + chunk_capacity_(rhs.chunk_capacity_), + baseAllocator_(rhs.baseAllocator_), + shared_(rhs.shared_) + { + RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0); + rhs.shared_ = 0; + } + MemoryPoolAllocator& operator=(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT + { + RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0); + this->~MemoryPoolAllocator(); + baseAllocator_ = rhs.baseAllocator_; + chunk_capacity_ = rhs.chunk_capacity_; + shared_ = rhs.shared_; + rhs.shared_ = 0; + return *this; + } +#endif + + //! Destructor. + /*! This deallocates all memory chunks, excluding the user-supplied buffer. + */ + ~MemoryPoolAllocator() RAPIDJSON_NOEXCEPT { + if (!shared_) { + // do nothing if moved + return; + } + if (shared_->refcount > 1) { + --shared_->refcount; + return; + } + Clear(); + BaseAllocator *a = shared_->ownBaseAllocator; + if (shared_->ownBuffer) { + baseAllocator_->Free(shared_); + } + RAPIDJSON_DELETE(a); + } + + //! Deallocates all memory chunks, excluding the first/user one. + void Clear() RAPIDJSON_NOEXCEPT { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + for (;;) { + ChunkHeader* c = shared_->chunkHead; + if (!c->next) { + break; + } + shared_->chunkHead = c->next; + baseAllocator_->Free(c); + } + shared_->chunkHead->size = 0; + } + + //! Computes the total capacity of allocated memory chunks. + /*! \return total capacity in bytes. + */ + size_t Capacity() const RAPIDJSON_NOEXCEPT { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + size_t capacity = 0; + for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next) + capacity += c->capacity; + return capacity; + } + + //! Computes the memory blocks allocated. + /*! \return total used bytes. + */ + size_t Size() const RAPIDJSON_NOEXCEPT { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + size_t size = 0; + for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next) + size += c->size; + return size; + } + + //! Whether the allocator is shared. + /*! \return true or false. + */ + bool Shared() const RAPIDJSON_NOEXCEPT { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + return shared_->refcount > 1; + } + + //! Allocates a memory block. (concept Allocator) + void* Malloc(size_t size) { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + if (!size) + return NULL; + + size = RAPIDJSON_ALIGN(size); + if (RAPIDJSON_UNLIKELY(shared_->chunkHead->size + size > shared_->chunkHead->capacity)) + if (!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size)) + return NULL; + + void *buffer = GetChunkBuffer(shared_) + shared_->chunkHead->size; + shared_->chunkHead->size += size; + return buffer; + } + + //! Resizes a memory block (concept Allocator) + void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) { + if (originalPtr == 0) + return Malloc(newSize); + + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + if (newSize == 0) + return NULL; + + originalSize = RAPIDJSON_ALIGN(originalSize); + newSize = RAPIDJSON_ALIGN(newSize); + + // Do not shrink if new size is smaller than original + if (originalSize >= newSize) + return originalPtr; + + // Simply expand it if it is the last allocation and there is sufficient space + if (originalPtr == GetChunkBuffer(shared_) + shared_->chunkHead->size - originalSize) { + size_t increment = static_cast(newSize - originalSize); + if (shared_->chunkHead->size + increment <= shared_->chunkHead->capacity) { + shared_->chunkHead->size += increment; + return originalPtr; + } + } + + // Realloc process: allocate and copy memory, do not free original buffer. + if (void* newBuffer = Malloc(newSize)) { + if (originalSize) + std::memcpy(newBuffer, originalPtr, originalSize); + return newBuffer; + } + else + return NULL; + } + + //! Frees a memory block (concept Allocator) + static void Free(void *ptr) RAPIDJSON_NOEXCEPT { (void)ptr; } // Do nothing + + //! Compare (equality) with another MemoryPoolAllocator + bool operator==(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT { + RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0); + RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0); + return shared_ == rhs.shared_; + } + //! Compare (inequality) with another MemoryPoolAllocator + bool operator!=(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT { + return !operator==(rhs); + } + +private: + //! Creates a new chunk. + /*! \param capacity Capacity of the chunk in bytes. + \return true if success. + */ + bool AddChunk(size_t capacity) { + if (!baseAllocator_) + shared_->ownBaseAllocator = baseAllocator_ = RAPIDJSON_NEW(BaseAllocator)(); + if (ChunkHeader* chunk = static_cast(baseAllocator_->Malloc(SIZEOF_CHUNK_HEADER + capacity))) { + chunk->capacity = capacity; + chunk->size = 0; + chunk->next = shared_->chunkHead; + shared_->chunkHead = chunk; + return true; + } + else + return false; + } + + static inline void* AlignBuffer(void* buf, size_t &size) + { + RAPIDJSON_NOEXCEPT_ASSERT(buf != 0); + const uintptr_t mask = sizeof(void*) - 1; + const uintptr_t ubuf = reinterpret_cast(buf); + if (RAPIDJSON_UNLIKELY(ubuf & mask)) { + const uintptr_t abuf = (ubuf + mask) & ~mask; + RAPIDJSON_ASSERT(size >= abuf - ubuf); + buf = reinterpret_cast(abuf); + size -= abuf - ubuf; + } + return buf; + } + + size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated. + BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks. + SharedData *shared_; //!< The shared data of the allocator +}; + +namespace internal { + template + struct IsRefCounted : + public FalseType + { }; + template + struct IsRefCounted::Type> : + public TrueType + { }; +} + +template +inline T* Realloc(A& a, T* old_p, size_t old_n, size_t new_n) +{ + RAPIDJSON_NOEXCEPT_ASSERT(old_n <= (std::numeric_limits::max)() / sizeof(T) && new_n <= (std::numeric_limits::max)() / sizeof(T)); + return static_cast(a.Realloc(old_p, old_n * sizeof(T), new_n * sizeof(T))); +} + +template +inline T *Malloc(A& a, size_t n = 1) +{ + return Realloc(a, NULL, 0, n); +} + +template +inline void Free(A& a, T *p, size_t n = 1) +{ + static_cast(Realloc(a, p, n, 0)); +} + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) // std::allocator can safely be inherited +#endif + +template +class StdAllocator : + public std::allocator +{ + typedef std::allocator allocator_type; +#if RAPIDJSON_HAS_CXX11 + typedef std::allocator_traits traits_type; +#else + typedef allocator_type traits_type; +#endif + +public: + typedef BaseAllocator BaseAllocatorType; + + StdAllocator() RAPIDJSON_NOEXCEPT : + allocator_type(), + baseAllocator_() + { } + + StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : + allocator_type(rhs), + baseAllocator_(rhs.baseAllocator_) + { } + + template + StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : + allocator_type(rhs), + baseAllocator_(rhs.baseAllocator_) + { } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + StdAllocator(StdAllocator&& rhs) RAPIDJSON_NOEXCEPT : + allocator_type(std::move(rhs)), + baseAllocator_(std::move(rhs.baseAllocator_)) + { } +#endif +#if RAPIDJSON_HAS_CXX11 + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; +#endif + + /* implicit */ + StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT : + allocator_type(), + baseAllocator_(baseAllocator) + { } + + ~StdAllocator() RAPIDJSON_NOEXCEPT + { } + + template + struct rebind { + typedef StdAllocator other; + }; + + typedef typename traits_type::size_type size_type; + typedef typename traits_type::difference_type difference_type; + + typedef typename traits_type::value_type value_type; + typedef typename traits_type::pointer pointer; + typedef typename traits_type::const_pointer const_pointer; + +#if RAPIDJSON_HAS_CXX11 + + typedef typename std::add_lvalue_reference::type &reference; + typedef typename std::add_lvalue_reference::type>::type &const_reference; + + pointer address(reference r) const RAPIDJSON_NOEXCEPT + { + return std::addressof(r); + } + const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT + { + return std::addressof(r); + } + + size_type max_size() const RAPIDJSON_NOEXCEPT + { + return traits_type::max_size(*this); + } + + template + void construct(pointer p, Args&&... args) + { + traits_type::construct(*this, p, std::forward(args)...); + } + void destroy(pointer p) + { + traits_type::destroy(*this, p); + } + +#else // !RAPIDJSON_HAS_CXX11 + + typedef typename allocator_type::reference reference; + typedef typename allocator_type::const_reference const_reference; + + pointer address(reference r) const RAPIDJSON_NOEXCEPT + { + return allocator_type::address(r); + } + const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT + { + return allocator_type::address(r); + } + + size_type max_size() const RAPIDJSON_NOEXCEPT + { + return allocator_type::max_size(); + } + + void construct(pointer p, const_reference r) + { + allocator_type::construct(p, r); + } + void destroy(pointer p) + { + allocator_type::destroy(p); + } + +#endif // !RAPIDJSON_HAS_CXX11 + + template + U* allocate(size_type n = 1, const void* = 0) + { + return RAPIDJSON_NAMESPACE::Malloc(baseAllocator_, n); + } + template + void deallocate(U* p, size_type n = 1) + { + RAPIDJSON_NAMESPACE::Free(baseAllocator_, p, n); + } + + pointer allocate(size_type n = 1, const void* = 0) + { + return allocate(n); + } + void deallocate(pointer p, size_type n = 1) + { + deallocate(p, n); + } + +#if RAPIDJSON_HAS_CXX11 + using is_always_equal = std::is_empty; +#endif + + template + bool operator==(const StdAllocator& rhs) const RAPIDJSON_NOEXCEPT + { + return baseAllocator_ == rhs.baseAllocator_; + } + template + bool operator!=(const StdAllocator& rhs) const RAPIDJSON_NOEXCEPT + { + return !operator==(rhs); + } + + //! rapidjson Allocator concept + static const bool kNeedFree = BaseAllocator::kNeedFree; + static const bool kRefCounted = internal::IsRefCounted::Value; + void* Malloc(size_t size) + { + return baseAllocator_.Malloc(size); + } + void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) + { + return baseAllocator_.Realloc(originalPtr, originalSize, newSize); + } + static void Free(void *ptr) RAPIDJSON_NOEXCEPT + { + BaseAllocator::Free(ptr); + } + +private: + template + friend class StdAllocator; // access to StdAllocator.* + + BaseAllocator baseAllocator_; +}; + +#if !RAPIDJSON_HAS_CXX17 // std::allocator deprecated in C++17 +template +class StdAllocator : + public std::allocator +{ + typedef std::allocator allocator_type; + +public: + typedef BaseAllocator BaseAllocatorType; + + StdAllocator() RAPIDJSON_NOEXCEPT : + allocator_type(), + baseAllocator_() + { } + + StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : + allocator_type(rhs), + baseAllocator_(rhs.baseAllocator_) + { } + + template + StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : + allocator_type(rhs), + baseAllocator_(rhs.baseAllocator_) + { } + + /* implicit */ + StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT : + allocator_type(), + baseAllocator_(baseAllocator) + { } + + ~StdAllocator() RAPIDJSON_NOEXCEPT + { } + + template + struct rebind { + typedef StdAllocator other; + }; + + typedef typename allocator_type::value_type value_type; + +private: + template + friend class StdAllocator; // access to StdAllocator.* + + BaseAllocator baseAllocator_; +}; +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_ENCODINGS_H_ diff --git a/include/rapidjson/cursorstreamwrapper.h b/include/rapidjson/cursorstreamwrapper.h new file mode 100644 index 0000000000..fd6513db14 --- /dev/null +++ b/include/rapidjson/cursorstreamwrapper.h @@ -0,0 +1,78 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_CURSORSTREAMWRAPPER_H_ +#define RAPIDJSON_CURSORSTREAMWRAPPER_H_ + +#include "stream.h" + +#if defined(__GNUC__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1800 +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4702) // unreachable code +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +RAPIDJSON_NAMESPACE_BEGIN + + +//! Cursor stream wrapper for counting line and column number if error exists. +/*! + \tparam InputStream Any stream that implements Stream Concept +*/ +template > +class CursorStreamWrapper : public GenericStreamWrapper { +public: + typedef typename Encoding::Ch Ch; + + CursorStreamWrapper(InputStream& is): + GenericStreamWrapper(is), line_(1), col_(0) {} + + // counting line and column number + Ch Take() { + Ch ch = this->is_.Take(); + if(ch == '\n') { + line_ ++; + col_ = 0; + } else { + col_ ++; + } + return ch; + } + + //! Get the error line number, if error exists. + size_t GetLine() const { return line_; } + //! Get the error column number, if error exists. + size_t GetColumn() const { return col_; } + +private: + size_t line_; //!< Current Line + size_t col_; //!< Current Column +}; + +#if defined(_MSC_VER) && _MSC_VER <= 1800 +RAPIDJSON_DIAG_POP +#endif + +#if defined(__GNUC__) +RAPIDJSON_DIAG_POP +#endif + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_CURSORSTREAMWRAPPER_H_ diff --git a/include/rapidjson/document.h b/include/rapidjson/document.h new file mode 100644 index 0000000000..4b2d723224 --- /dev/null +++ b/include/rapidjson/document.h @@ -0,0 +1,3044 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_DOCUMENT_H_ +#define RAPIDJSON_DOCUMENT_H_ + +/*! \file document.h */ + +#include "reader.h" +#include "internal/meta.h" +#include "internal/strfunc.h" +#include "memorystream.h" +#include "encodedstream.h" +#include // placement new +#include +#ifdef __cpp_lib_three_way_comparison +#include +#endif + +RAPIDJSON_DIAG_PUSH +#ifdef __clang__ +RAPIDJSON_DIAG_OFF(padded) +RAPIDJSON_DIAG_OFF(switch-enum) +RAPIDJSON_DIAG_OFF(c++98-compat) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_OFF(4127) // conditional expression is constant +RAPIDJSON_DIAG_OFF(4244) // conversion from kXxxFlags to 'uint16_t', possible loss of data +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_OFF(effc++) +#endif // __GNUC__ + +#ifdef GetObject +// see https://github.com/Tencent/rapidjson/issues/1448 +// a former included windows.h might have defined a macro called GetObject, which affects +// GetObject defined here. This ensures the macro does not get applied +#pragma push_macro("GetObject") +#define RAPIDJSON_WINDOWS_GETOBJECT_WORKAROUND_APPLIED +#undef GetObject +#endif + +#ifndef RAPIDJSON_NOMEMBERITERATORCLASS +#include // std::random_access_iterator_tag +#endif + +#if RAPIDJSON_USE_MEMBERSMAP +#include // std::multimap +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +// Forward declaration. +template +class GenericValue; + +template +class GenericDocument; + +/*! \def RAPIDJSON_DEFAULT_ALLOCATOR + \ingroup RAPIDJSON_CONFIG + \brief Allows to choose default allocator. + + User can define this to use CrtAllocator or MemoryPoolAllocator. +*/ +#ifndef RAPIDJSON_DEFAULT_ALLOCATOR +#define RAPIDJSON_DEFAULT_ALLOCATOR ::RAPIDJSON_NAMESPACE::MemoryPoolAllocator<::RAPIDJSON_NAMESPACE::CrtAllocator> +#endif + +/*! \def RAPIDJSON_DEFAULT_STACK_ALLOCATOR + \ingroup RAPIDJSON_CONFIG + \brief Allows to choose default stack allocator for Document. + + User can define this to use CrtAllocator or MemoryPoolAllocator. +*/ +#ifndef RAPIDJSON_DEFAULT_STACK_ALLOCATOR +#define RAPIDJSON_DEFAULT_STACK_ALLOCATOR ::RAPIDJSON_NAMESPACE::CrtAllocator +#endif + +/*! \def RAPIDJSON_VALUE_DEFAULT_OBJECT_CAPACITY + \ingroup RAPIDJSON_CONFIG + \brief User defined kDefaultObjectCapacity value. + + User can define this as any natural number. +*/ +#ifndef RAPIDJSON_VALUE_DEFAULT_OBJECT_CAPACITY +// number of objects that rapidjson::Value allocates memory for by default +#define RAPIDJSON_VALUE_DEFAULT_OBJECT_CAPACITY 16 +#endif + +/*! \def RAPIDJSON_VALUE_DEFAULT_ARRAY_CAPACITY + \ingroup RAPIDJSON_CONFIG + \brief User defined kDefaultArrayCapacity value. + + User can define this as any natural number. +*/ +#ifndef RAPIDJSON_VALUE_DEFAULT_ARRAY_CAPACITY +// number of array elements that rapidjson::Value allocates memory for by default +#define RAPIDJSON_VALUE_DEFAULT_ARRAY_CAPACITY 16 +#endif + +//! Name-value pair in a JSON object value. +/*! + This class was internal to GenericValue. It used to be a inner struct. + But a compiler (IBM XL C/C++ for AIX) have reported to have problem with that so it moved as a namespace scope struct. + https://code.google.com/p/rapidjson/issues/detail?id=64 +*/ +template +class GenericMember { +public: + GenericValue name; //!< name of member (must be a string) + GenericValue value; //!< value of member. + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move constructor in C++11 + GenericMember(GenericMember&& rhs) RAPIDJSON_NOEXCEPT + : name(std::move(rhs.name)), + value(std::move(rhs.value)) + { + } + + //! Move assignment in C++11 + GenericMember& operator=(GenericMember&& rhs) RAPIDJSON_NOEXCEPT { + return *this = static_cast(rhs); + } +#endif + + //! Assignment with move semantics. + /*! \param rhs Source of the assignment. Its name and value will become a null value after assignment. + */ + GenericMember& operator=(GenericMember& rhs) RAPIDJSON_NOEXCEPT { + if (RAPIDJSON_LIKELY(this != &rhs)) { + name = rhs.name; + value = rhs.value; + } + return *this; + } + + // swap() for std::sort() and other potential use in STL. + friend inline void swap(GenericMember& a, GenericMember& b) RAPIDJSON_NOEXCEPT { + a.name.Swap(b.name); + a.value.Swap(b.value); + } + +private: + //! Copy constructor is not permitted. + GenericMember(const GenericMember& rhs); +}; + +/////////////////////////////////////////////////////////////////////////////// +// GenericMemberIterator + +#ifndef RAPIDJSON_NOMEMBERITERATORCLASS + +//! (Constant) member iterator for a JSON object value +/*! + \tparam Const Is this a constant iterator? + \tparam Encoding Encoding of the value. (Even non-string values need to have the same encoding in a document) + \tparam Allocator Allocator type for allocating memory of object, array and string. + + This class implements a Random Access Iterator for GenericMember elements + of a GenericValue, see ISO/IEC 14882:2003(E) C++ standard, 24.1 [lib.iterator.requirements]. + + \note This iterator implementation is mainly intended to avoid implicit + conversions from iterator values to \c NULL, + e.g. from GenericValue::FindMember. + + \note Define \c RAPIDJSON_NOMEMBERITERATORCLASS to fall back to a + pointer-based implementation, if your platform doesn't provide + the C++ header. + + \see GenericMember, GenericValue::MemberIterator, GenericValue::ConstMemberIterator + */ +template +class GenericMemberIterator { + + friend class GenericValue; + template friend class GenericMemberIterator; + + typedef GenericMember PlainType; + typedef typename internal::MaybeAddConst::Type ValueType; + +public: + //! Iterator type itself + typedef GenericMemberIterator Iterator; + //! Constant iterator type + typedef GenericMemberIterator ConstIterator; + //! Non-constant iterator type + typedef GenericMemberIterator NonConstIterator; + + /** \name std::iterator_traits support */ + //@{ + typedef ValueType value_type; + typedef ValueType * pointer; + typedef ValueType & reference; + typedef std::ptrdiff_t difference_type; + typedef std::random_access_iterator_tag iterator_category; + //@} + + //! Pointer to (const) GenericMember + typedef pointer Pointer; + //! Reference to (const) GenericMember + typedef reference Reference; + //! Signed integer type (e.g. \c ptrdiff_t) + typedef difference_type DifferenceType; + + //! Default constructor (singular value) + /*! Creates an iterator pointing to no element. + \note All operations, except for comparisons, are undefined on such values. + */ + GenericMemberIterator() : ptr_() {} + + //! Iterator conversions to more const + /*! + \param it (Non-const) iterator to copy from + + Allows the creation of an iterator from another GenericMemberIterator + that is "less const". Especially, creating a non-constant iterator + from a constant iterator are disabled: + \li const -> non-const (not ok) + \li const -> const (ok) + \li non-const -> const (ok) + \li non-const -> non-const (ok) + + \note If the \c Const template parameter is already \c false, this + constructor effectively defines a regular copy-constructor. + Otherwise, the copy constructor is implicitly defined. + */ + GenericMemberIterator(const NonConstIterator & it) : ptr_(it.ptr_) {} + Iterator& operator=(const NonConstIterator & it) { ptr_ = it.ptr_; return *this; } + + //! @name stepping + //@{ + Iterator& operator++(){ ++ptr_; return *this; } + Iterator& operator--(){ --ptr_; return *this; } + Iterator operator++(int){ Iterator old(*this); ++ptr_; return old; } + Iterator operator--(int){ Iterator old(*this); --ptr_; return old; } + //@} + + //! @name increment/decrement + //@{ + Iterator operator+(DifferenceType n) const { return Iterator(ptr_+n); } + Iterator operator-(DifferenceType n) const { return Iterator(ptr_-n); } + + Iterator& operator+=(DifferenceType n) { ptr_+=n; return *this; } + Iterator& operator-=(DifferenceType n) { ptr_-=n; return *this; } + //@} + + //! @name relations + //@{ + template bool operator==(const GenericMemberIterator& that) const { return ptr_ == that.ptr_; } + template bool operator!=(const GenericMemberIterator& that) const { return ptr_ != that.ptr_; } + template bool operator<=(const GenericMemberIterator& that) const { return ptr_ <= that.ptr_; } + template bool operator>=(const GenericMemberIterator& that) const { return ptr_ >= that.ptr_; } + template bool operator< (const GenericMemberIterator& that) const { return ptr_ < that.ptr_; } + template bool operator> (const GenericMemberIterator& that) const { return ptr_ > that.ptr_; } + +#ifdef __cpp_lib_three_way_comparison + template std::strong_ordering operator<=>(const GenericMemberIterator& that) const { return ptr_ <=> that.ptr_; } +#endif + //@} + + //! @name dereference + //@{ + Reference operator*() const { return *ptr_; } + Pointer operator->() const { return ptr_; } + Reference operator[](DifferenceType n) const { return ptr_[n]; } + //@} + + //! Distance + DifferenceType operator-(ConstIterator that) const { return ptr_-that.ptr_; } + +private: + //! Internal constructor from plain pointer + explicit GenericMemberIterator(Pointer p) : ptr_(p) {} + + Pointer ptr_; //!< raw pointer +}; + +#else // RAPIDJSON_NOMEMBERITERATORCLASS + +// class-based member iterator implementation disabled, use plain pointers + +template +class GenericMemberIterator; + +//! non-const GenericMemberIterator +template +class GenericMemberIterator { +public: + //! use plain pointer as iterator type + typedef GenericMember* Iterator; +}; +//! const GenericMemberIterator +template +class GenericMemberIterator { +public: + //! use plain const pointer as iterator type + typedef const GenericMember* Iterator; +}; + +#endif // RAPIDJSON_NOMEMBERITERATORCLASS + +/////////////////////////////////////////////////////////////////////////////// +// GenericStringRef + +//! Reference to a constant string (not taking a copy) +/*! + \tparam CharType character type of the string + + This helper class is used to automatically infer constant string + references for string literals, especially from \c const \b (!) + character arrays. + + The main use is for creating JSON string values without copying the + source string via an \ref Allocator. This requires that the referenced + string pointers have a sufficient lifetime, which exceeds the lifetime + of the associated GenericValue. + + \b Example + \code + Value v("foo"); // ok, no need to copy & calculate length + const char foo[] = "foo"; + v.SetString(foo); // ok + + const char* bar = foo; + // Value x(bar); // not ok, can't rely on bar's lifetime + Value x(StringRef(bar)); // lifetime explicitly guaranteed by user + Value y(StringRef(bar, 3)); // ok, explicitly pass length + \endcode + + \see StringRef, GenericValue::SetString +*/ +template +struct GenericStringRef { + typedef CharType Ch; //!< character type of the string + + //! Create string reference from \c const character array +#ifndef __clang__ // -Wdocumentation + /*! + This constructor implicitly creates a constant string reference from + a \c const character array. It has better performance than + \ref StringRef(const CharType*) by inferring the string \ref length + from the array length, and also supports strings containing null + characters. + + \tparam N length of the string, automatically inferred + + \param str Constant character array, lifetime assumed to be longer + than the use of the string in e.g. a GenericValue + + \post \ref s == str + + \note Constant complexity. + \note There is a hidden, private overload to disallow references to + non-const character arrays to be created via this constructor. + By this, e.g. function-scope arrays used to be filled via + \c snprintf are excluded from consideration. + In such cases, the referenced string should be \b copied to the + GenericValue instead. + */ +#endif + template + GenericStringRef(const CharType (&str)[N]) RAPIDJSON_NOEXCEPT + : s(str), length(N-1) {} + + //! Explicitly create string reference from \c const character pointer +#ifndef __clang__ // -Wdocumentation + /*! + This constructor can be used to \b explicitly create a reference to + a constant string pointer. + + \see StringRef(const CharType*) + + \param str Constant character pointer, lifetime assumed to be longer + than the use of the string in e.g. a GenericValue + + \post \ref s == str + + \note There is a hidden, private overload to disallow references to + non-const character arrays to be created via this constructor. + By this, e.g. function-scope arrays used to be filled via + \c snprintf are excluded from consideration. + In such cases, the referenced string should be \b copied to the + GenericValue instead. + */ +#endif + explicit GenericStringRef(const CharType* str) + : s(str), length(NotNullStrLen(str)) {} + + //! Create constant string reference from pointer and length +#ifndef __clang__ // -Wdocumentation + /*! \param str constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue + \param len length of the string, excluding the trailing NULL terminator + + \post \ref s == str && \ref length == len + \note Constant complexity. + */ +#endif + GenericStringRef(const CharType* str, SizeType len) + : s(RAPIDJSON_LIKELY(str) ? str : emptyString), length(len) { RAPIDJSON_ASSERT(str != 0 || len == 0u); } + + GenericStringRef(const GenericStringRef& rhs) : s(rhs.s), length(rhs.length) {} + + //! implicit conversion to plain CharType pointer + operator const Ch *() const { return s; } + + const Ch* const s; //!< plain CharType pointer + const SizeType length; //!< length of the string (excluding the trailing NULL terminator) + +private: + SizeType NotNullStrLen(const CharType* str) { + RAPIDJSON_ASSERT(str != 0); + return internal::StrLen(str); + } + + /// Empty string - used when passing in a NULL pointer + static const Ch emptyString[]; + + //! Disallow construction from non-const array + template + GenericStringRef(CharType (&str)[N]) /* = delete */; + //! Copy assignment operator not permitted - immutable type + GenericStringRef& operator=(const GenericStringRef& rhs) /* = delete */; +}; + +template +const CharType GenericStringRef::emptyString[] = { CharType() }; + +//! Mark a character pointer as constant string +/*! Mark a plain character pointer as a "string literal". This function + can be used to avoid copying a character string to be referenced as a + value in a JSON GenericValue object, if the string's lifetime is known + to be valid long enough. + \tparam CharType Character type of the string + \param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue + \return GenericStringRef string reference object + \relatesalso GenericStringRef + + \see GenericValue::GenericValue(StringRefType), GenericValue::operator=(StringRefType), GenericValue::SetString(StringRefType), GenericValue::PushBack(StringRefType, Allocator&), GenericValue::AddMember +*/ +template +inline GenericStringRef StringRef(const CharType* str) { + return GenericStringRef(str); +} + +//! Mark a character pointer as constant string +/*! Mark a plain character pointer as a "string literal". This function + can be used to avoid copying a character string to be referenced as a + value in a JSON GenericValue object, if the string's lifetime is known + to be valid long enough. + + This version has better performance with supplied length, and also + supports string containing null characters. + + \tparam CharType character type of the string + \param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue + \param length The length of source string. + \return GenericStringRef string reference object + \relatesalso GenericStringRef +*/ +template +inline GenericStringRef StringRef(const CharType* str, size_t length) { + return GenericStringRef(str, SizeType(length)); +} + +#if RAPIDJSON_HAS_STDSTRING +//! Mark a string object as constant string +/*! Mark a string object (e.g. \c std::string) as a "string literal". + This function can be used to avoid copying a string to be referenced as a + value in a JSON GenericValue object, if the string's lifetime is known + to be valid long enough. + + \tparam CharType character type of the string + \param str Constant string, lifetime assumed to be longer than the use of the string in e.g. a GenericValue + \return GenericStringRef string reference object + \relatesalso GenericStringRef + \note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING. +*/ +template +inline GenericStringRef StringRef(const std::basic_string& str) { + return GenericStringRef(str.data(), SizeType(str.size())); +} +#endif + +/////////////////////////////////////////////////////////////////////////////// +// GenericValue type traits +namespace internal { + +template +struct IsGenericValueImpl : FalseType {}; + +// select candidates according to nested encoding and allocator types +template struct IsGenericValueImpl::Type, typename Void::Type> + : IsBaseOf, T>::Type {}; + +// helper to match arbitrary GenericValue instantiations, including derived classes +template struct IsGenericValue : IsGenericValueImpl::Type {}; + +} // namespace internal + +/////////////////////////////////////////////////////////////////////////////// +// TypeHelper + +namespace internal { + +template +struct TypeHelper {}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsBool(); } + static bool Get(const ValueType& v) { return v.GetBool(); } + static ValueType& Set(ValueType& v, bool data) { return v.SetBool(data); } + static ValueType& Set(ValueType& v, bool data, typename ValueType::AllocatorType&) { return v.SetBool(data); } +}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsInt(); } + static int Get(const ValueType& v) { return v.GetInt(); } + static ValueType& Set(ValueType& v, int data) { return v.SetInt(data); } + static ValueType& Set(ValueType& v, int data, typename ValueType::AllocatorType&) { return v.SetInt(data); } +}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsUint(); } + static unsigned Get(const ValueType& v) { return v.GetUint(); } + static ValueType& Set(ValueType& v, unsigned data) { return v.SetUint(data); } + static ValueType& Set(ValueType& v, unsigned data, typename ValueType::AllocatorType&) { return v.SetUint(data); } +}; + +#ifdef _MSC_VER +RAPIDJSON_STATIC_ASSERT(sizeof(long) == sizeof(int)); +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsInt(); } + static long Get(const ValueType& v) { return v.GetInt(); } + static ValueType& Set(ValueType& v, long data) { return v.SetInt(data); } + static ValueType& Set(ValueType& v, long data, typename ValueType::AllocatorType&) { return v.SetInt(data); } +}; + +RAPIDJSON_STATIC_ASSERT(sizeof(unsigned long) == sizeof(unsigned)); +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsUint(); } + static unsigned long Get(const ValueType& v) { return v.GetUint(); } + static ValueType& Set(ValueType& v, unsigned long data) { return v.SetUint(data); } + static ValueType& Set(ValueType& v, unsigned long data, typename ValueType::AllocatorType&) { return v.SetUint(data); } +}; +#endif + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsInt64(); } + static int64_t Get(const ValueType& v) { return v.GetInt64(); } + static ValueType& Set(ValueType& v, int64_t data) { return v.SetInt64(data); } + static ValueType& Set(ValueType& v, int64_t data, typename ValueType::AllocatorType&) { return v.SetInt64(data); } +}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsUint64(); } + static uint64_t Get(const ValueType& v) { return v.GetUint64(); } + static ValueType& Set(ValueType& v, uint64_t data) { return v.SetUint64(data); } + static ValueType& Set(ValueType& v, uint64_t data, typename ValueType::AllocatorType&) { return v.SetUint64(data); } +}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsDouble(); } + static double Get(const ValueType& v) { return v.GetDouble(); } + static ValueType& Set(ValueType& v, double data) { return v.SetDouble(data); } + static ValueType& Set(ValueType& v, double data, typename ValueType::AllocatorType&) { return v.SetDouble(data); } +}; + +template +struct TypeHelper { + static bool Is(const ValueType& v) { return v.IsFloat(); } + static float Get(const ValueType& v) { return v.GetFloat(); } + static ValueType& Set(ValueType& v, float data) { return v.SetFloat(data); } + static ValueType& Set(ValueType& v, float data, typename ValueType::AllocatorType&) { return v.SetFloat(data); } +}; + +template +struct TypeHelper { + typedef const typename ValueType::Ch* StringType; + static bool Is(const ValueType& v) { return v.IsString(); } + static StringType Get(const ValueType& v) { return v.GetString(); } + static ValueType& Set(ValueType& v, const StringType data) { return v.SetString(typename ValueType::StringRefType(data)); } + static ValueType& Set(ValueType& v, const StringType data, typename ValueType::AllocatorType& a) { return v.SetString(data, a); } +}; + +#if RAPIDJSON_HAS_STDSTRING +template +struct TypeHelper > { + typedef std::basic_string StringType; + static bool Is(const ValueType& v) { return v.IsString(); } + static StringType Get(const ValueType& v) { return StringType(v.GetString(), v.GetStringLength()); } + static ValueType& Set(ValueType& v, const StringType& data, typename ValueType::AllocatorType& a) { return v.SetString(data, a); } +}; +#endif + +template +struct TypeHelper { + typedef typename ValueType::Array ArrayType; + static bool Is(const ValueType& v) { return v.IsArray(); } + static ArrayType Get(ValueType& v) { return v.GetArray(); } + static ValueType& Set(ValueType& v, ArrayType data) { return v = data; } + static ValueType& Set(ValueType& v, ArrayType data, typename ValueType::AllocatorType&) { return v = data; } +}; + +template +struct TypeHelper { + typedef typename ValueType::ConstArray ArrayType; + static bool Is(const ValueType& v) { return v.IsArray(); } + static ArrayType Get(const ValueType& v) { return v.GetArray(); } +}; + +template +struct TypeHelper { + typedef typename ValueType::Object ObjectType; + static bool Is(const ValueType& v) { return v.IsObject(); } + static ObjectType Get(ValueType& v) { return v.GetObject(); } + static ValueType& Set(ValueType& v, ObjectType data) { return v = data; } + static ValueType& Set(ValueType& v, ObjectType data, typename ValueType::AllocatorType&) { return v = data; } +}; + +template +struct TypeHelper { + typedef typename ValueType::ConstObject ObjectType; + static bool Is(const ValueType& v) { return v.IsObject(); } + static ObjectType Get(const ValueType& v) { return v.GetObject(); } +}; + +} // namespace internal + +// Forward declarations +template class GenericArray; +template class GenericObject; + +/////////////////////////////////////////////////////////////////////////////// +// GenericValue + +//! Represents a JSON value. Use Value for UTF8 encoding and default allocator. +/*! + A JSON value can be one of 7 types. This class is a variant type supporting + these types. + + Use the Value if UTF8 and default allocator + + \tparam Encoding Encoding of the value. (Even non-string values need to have the same encoding in a document) + \tparam Allocator Allocator type for allocating memory of object, array and string. +*/ +template +class GenericValue { +public: + //! Name-value pair in an object. + typedef GenericMember Member; + typedef Encoding EncodingType; //!< Encoding type from template parameter. + typedef Allocator AllocatorType; //!< Allocator type from template parameter. + typedef typename Encoding::Ch Ch; //!< Character type derived from Encoding. + typedef GenericStringRef StringRefType; //!< Reference to a constant string + typedef typename GenericMemberIterator::Iterator MemberIterator; //!< Member iterator for iterating in object. + typedef typename GenericMemberIterator::Iterator ConstMemberIterator; //!< Constant member iterator for iterating in object. + typedef GenericValue* ValueIterator; //!< Value iterator for iterating in array. + typedef const GenericValue* ConstValueIterator; //!< Constant value iterator for iterating in array. + typedef GenericValue ValueType; //!< Value type of itself. + typedef GenericArray Array; + typedef GenericArray ConstArray; + typedef GenericObject Object; + typedef GenericObject ConstObject; + + //!@name Constructors and destructor. + //@{ + + //! Default constructor creates a null value. + GenericValue() RAPIDJSON_NOEXCEPT : data_() { data_.f.flags = kNullFlag; } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move constructor in C++11 + GenericValue(GenericValue&& rhs) RAPIDJSON_NOEXCEPT : data_(rhs.data_) { + rhs.data_.f.flags = kNullFlag; // give up contents + } +#endif + +private: + //! Copy constructor is not permitted. + GenericValue(const GenericValue& rhs); + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Moving from a GenericDocument is not permitted. + template + GenericValue(GenericDocument&& rhs); + + //! Move assignment from a GenericDocument is not permitted. + template + GenericValue& operator=(GenericDocument&& rhs); +#endif + +public: + + //! Constructor with JSON value type. + /*! This creates a Value of specified type with default content. + \param type Type of the value. + \note Default content for number is zero. + */ + explicit GenericValue(Type type) RAPIDJSON_NOEXCEPT : data_() { + static const uint16_t defaultFlags[] = { + kNullFlag, kFalseFlag, kTrueFlag, kObjectFlag, kArrayFlag, kShortStringFlag, + kNumberAnyFlag + }; + RAPIDJSON_NOEXCEPT_ASSERT(type >= kNullType && type <= kNumberType); + data_.f.flags = defaultFlags[type]; + + // Use ShortString to store empty string. + if (type == kStringType) + data_.ss.SetLength(0); + } + + //! Explicit copy constructor (with allocator) + /*! Creates a copy of a Value by using the given Allocator + \tparam SourceAllocator allocator of \c rhs + \param rhs Value to copy from (read-only) + \param allocator Allocator for allocating copied elements and buffers. Commonly use GenericDocument::GetAllocator(). + \param copyConstStrings Force copying of constant strings (e.g. referencing an in-situ buffer) + \see CopyFrom() + */ + template + GenericValue(const GenericValue& rhs, Allocator& allocator, bool copyConstStrings = false) { + switch (rhs.GetType()) { + case kObjectType: + DoCopyMembers(rhs, allocator, copyConstStrings); + break; + case kArrayType: { + SizeType count = rhs.data_.a.size; + GenericValue* le = reinterpret_cast(allocator.Malloc(count * sizeof(GenericValue))); + const GenericValue* re = rhs.GetElementsPointer(); + for (SizeType i = 0; i < count; i++) + new (&le[i]) GenericValue(re[i], allocator, copyConstStrings); + data_.f.flags = kArrayFlag; + data_.a.size = data_.a.capacity = count; + SetElementsPointer(le); + } + break; + case kStringType: + if (rhs.data_.f.flags == kConstStringFlag && !copyConstStrings) { + data_.f.flags = rhs.data_.f.flags; + data_ = *reinterpret_cast(&rhs.data_); + } + else + SetStringRaw(StringRef(rhs.GetString(), rhs.GetStringLength()), allocator); + break; + default: + data_.f.flags = rhs.data_.f.flags; + data_ = *reinterpret_cast(&rhs.data_); + break; + } + } + + //! Constructor for boolean value. + /*! \param b Boolean value + \note This constructor is limited to \em real boolean values and rejects + implicitly converted types like arbitrary pointers. Use an explicit cast + to \c bool, if you want to construct a boolean JSON value in such cases. + */ +#ifndef RAPIDJSON_DOXYGEN_RUNNING // hide SFINAE from Doxygen + template + explicit GenericValue(T b, RAPIDJSON_ENABLEIF((internal::IsSame))) RAPIDJSON_NOEXCEPT // See #472 +#else + explicit GenericValue(bool b) RAPIDJSON_NOEXCEPT +#endif + : data_() { + // safe-guard against failing SFINAE + RAPIDJSON_STATIC_ASSERT((internal::IsSame::Value)); + data_.f.flags = b ? kTrueFlag : kFalseFlag; + } + + //! Constructor for int value. + explicit GenericValue(int i) RAPIDJSON_NOEXCEPT : data_() { + data_.n.i64 = i; + data_.f.flags = (i >= 0) ? (kNumberIntFlag | kUintFlag | kUint64Flag) : kNumberIntFlag; + } + + //! Constructor for unsigned value. + explicit GenericValue(unsigned u) RAPIDJSON_NOEXCEPT : data_() { + data_.n.u64 = u; + data_.f.flags = (u & 0x80000000) ? kNumberUintFlag : (kNumberUintFlag | kIntFlag | kInt64Flag); + } + + //! Constructor for int64_t value. + explicit GenericValue(int64_t i64) RAPIDJSON_NOEXCEPT : data_() { + data_.n.i64 = i64; + data_.f.flags = kNumberInt64Flag; + if (i64 >= 0) { + data_.f.flags |= kNumberUint64Flag; + if (!(static_cast(i64) & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x00000000))) + data_.f.flags |= kUintFlag; + if (!(static_cast(i64) & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000))) + data_.f.flags |= kIntFlag; + } + else if (i64 >= static_cast(RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000))) + data_.f.flags |= kIntFlag; + } + + //! Constructor for uint64_t value. + explicit GenericValue(uint64_t u64) RAPIDJSON_NOEXCEPT : data_() { + data_.n.u64 = u64; + data_.f.flags = kNumberUint64Flag; + if (!(u64 & RAPIDJSON_UINT64_C2(0x80000000, 0x00000000))) + data_.f.flags |= kInt64Flag; + if (!(u64 & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x00000000))) + data_.f.flags |= kUintFlag; + if (!(u64 & RAPIDJSON_UINT64_C2(0xFFFFFFFF, 0x80000000))) + data_.f.flags |= kIntFlag; + } + + //! Constructor for double value. + explicit GenericValue(double d) RAPIDJSON_NOEXCEPT : data_() { data_.n.d = d; data_.f.flags = kNumberDoubleFlag; } + + //! Constructor for float value. + explicit GenericValue(float f) RAPIDJSON_NOEXCEPT : data_() { data_.n.d = static_cast(f); data_.f.flags = kNumberDoubleFlag; } + + //! Constructor for constant string (i.e. do not make a copy of string) + GenericValue(const Ch* s, SizeType length) RAPIDJSON_NOEXCEPT : data_() { SetStringRaw(StringRef(s, length)); } + + //! Constructor for constant string (i.e. do not make a copy of string) + explicit GenericValue(StringRefType s) RAPIDJSON_NOEXCEPT : data_() { SetStringRaw(s); } + + //! Constructor for copy-string (i.e. do make a copy of string) + GenericValue(const Ch* s, SizeType length, Allocator& allocator) : data_() { SetStringRaw(StringRef(s, length), allocator); } + + //! Constructor for copy-string (i.e. do make a copy of string) + GenericValue(const Ch*s, Allocator& allocator) : data_() { SetStringRaw(StringRef(s), allocator); } + +#if RAPIDJSON_HAS_STDSTRING + //! Constructor for copy-string from a string object (i.e. do make a copy of string) + /*! \note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING. + */ + GenericValue(const std::basic_string& s, Allocator& allocator) : data_() { SetStringRaw(StringRef(s), allocator); } +#endif + + //! Constructor for Array. + /*! + \param a An array obtained by \c GetArray(). + \note \c Array is always pass-by-value. + \note the source array is moved into this value and the sourec array becomes empty. + */ + GenericValue(Array a) RAPIDJSON_NOEXCEPT : data_(a.value_.data_) { + a.value_.data_ = Data(); + a.value_.data_.f.flags = kArrayFlag; + } + + //! Constructor for Object. + /*! + \param o An object obtained by \c GetObject(). + \note \c Object is always pass-by-value. + \note the source object is moved into this value and the sourec object becomes empty. + */ + GenericValue(Object o) RAPIDJSON_NOEXCEPT : data_(o.value_.data_) { + o.value_.data_ = Data(); + o.value_.data_.f.flags = kObjectFlag; + } + + //! Destructor. + /*! Need to destruct elements of array, members of object, or copy-string. + */ + ~GenericValue() { + // With RAPIDJSON_USE_MEMBERSMAP, the maps need to be destroyed to release + // their Allocator if it's refcounted (e.g. MemoryPoolAllocator). + if (Allocator::kNeedFree || (RAPIDJSON_USE_MEMBERSMAP+0 && + internal::IsRefCounted::Value)) { + switch(data_.f.flags) { + case kArrayFlag: + { + GenericValue* e = GetElementsPointer(); + for (GenericValue* v = e; v != e + data_.a.size; ++v) + v->~GenericValue(); + if (Allocator::kNeedFree) { // Shortcut by Allocator's trait + Allocator::Free(e); + } + } + break; + + case kObjectFlag: + DoFreeMembers(); + break; + + case kCopyStringFlag: + if (Allocator::kNeedFree) { // Shortcut by Allocator's trait + Allocator::Free(const_cast(GetStringPointer())); + } + break; + + default: + break; // Do nothing for other types. + } + } + } + + //@} + + //!@name Assignment operators + //@{ + + //! Assignment with move semantics. + /*! \param rhs Source of the assignment. It will become a null value after assignment. + */ + GenericValue& operator=(GenericValue& rhs) RAPIDJSON_NOEXCEPT { + if (RAPIDJSON_LIKELY(this != &rhs)) { + // Can't destroy "this" before assigning "rhs", otherwise "rhs" + // could be used after free if it's an sub-Value of "this", + // hence the temporary danse. + GenericValue temp; + temp.RawAssign(rhs); + this->~GenericValue(); + RawAssign(temp); + } + return *this; + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move assignment in C++11 + GenericValue& operator=(GenericValue&& rhs) RAPIDJSON_NOEXCEPT { + return *this = rhs.Move(); + } +#endif + + //! Assignment of constant string reference (no copy) + /*! \param str Constant string reference to be assigned + \note This overload is needed to avoid clashes with the generic primitive type assignment overload below. + \see GenericStringRef, operator=(T) + */ + GenericValue& operator=(StringRefType str) RAPIDJSON_NOEXCEPT { + GenericValue s(str); + return *this = s; + } + + //! Assignment with primitive types. + /*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t + \param value The value to be assigned. + + \note The source type \c T explicitly disallows all pointer types, + especially (\c const) \ref Ch*. This helps avoiding implicitly + referencing character strings with insufficient lifetime, use + \ref SetString(const Ch*, Allocator&) (for copying) or + \ref StringRef() (to explicitly mark the pointer as constant) instead. + All other pointer types would implicitly convert to \c bool, + use \ref SetBool() instead. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::IsPointer), (GenericValue&)) + operator=(T value) { + GenericValue v(value); + return *this = v; + } + + //! Deep-copy assignment from Value + /*! Assigns a \b copy of the Value to the current Value object + \tparam SourceAllocator Allocator type of \c rhs + \param rhs Value to copy from (read-only) + \param allocator Allocator to use for copying + \param copyConstStrings Force copying of constant strings (e.g. referencing an in-situ buffer) + */ + template + GenericValue& CopyFrom(const GenericValue& rhs, Allocator& allocator, bool copyConstStrings = false) { + RAPIDJSON_ASSERT(static_cast(this) != static_cast(&rhs)); + this->~GenericValue(); + new (this) GenericValue(rhs, allocator, copyConstStrings); + return *this; + } + + //! Exchange the contents of this value with those of other. + /*! + \param other Another value. + \note Constant complexity. + */ + GenericValue& Swap(GenericValue& other) RAPIDJSON_NOEXCEPT { + GenericValue temp; + temp.RawAssign(*this); + RawAssign(other); + other.RawAssign(temp); + return *this; + } + + //! free-standing swap function helper + /*! + Helper function to enable support for common swap implementation pattern based on \c std::swap: + \code + void swap(MyClass& a, MyClass& b) { + using std::swap; + swap(a.value, b.value); + // ... + } + \endcode + \see Swap() + */ + friend inline void swap(GenericValue& a, GenericValue& b) RAPIDJSON_NOEXCEPT { a.Swap(b); } + + //! Prepare Value for move semantics + /*! \return *this */ + GenericValue& Move() RAPIDJSON_NOEXCEPT { return *this; } + //@} + + //!@name Equal-to and not-equal-to operators + //@{ + //! Equal-to operator + /*! + \note If an object contains duplicated named member, comparing equality with any object is always \c false. + \note Complexity is quadratic in Object's member number and linear for the rest (number of all values in the subtree and total lengths of all strings). + */ + template + bool operator==(const GenericValue& rhs) const { + typedef GenericValue RhsType; + if (GetType() != rhs.GetType()) + return false; + + switch (GetType()) { + case kObjectType: // Warning: O(n^2) inner-loop + if (data_.o.size != rhs.data_.o.size) + return false; + for (ConstMemberIterator lhsMemberItr = MemberBegin(); lhsMemberItr != MemberEnd(); ++lhsMemberItr) { + typename RhsType::ConstMemberIterator rhsMemberItr = rhs.FindMember(lhsMemberItr->name); + if (rhsMemberItr == rhs.MemberEnd() || (!(lhsMemberItr->value == rhsMemberItr->value))) + return false; + } + return true; + + case kArrayType: + if (data_.a.size != rhs.data_.a.size) + return false; + for (SizeType i = 0; i < data_.a.size; i++) + if (!((*this)[i] == rhs[i])) + return false; + return true; + + case kStringType: + return StringEqual(rhs); + + case kNumberType: + if (IsDouble() || rhs.IsDouble()) { + double a = GetDouble(); // May convert from integer to double. + double b = rhs.GetDouble(); // Ditto + return a >= b && a <= b; // Prevent -Wfloat-equal + } + else + return data_.n.u64 == rhs.data_.n.u64; + + default: + return true; + } + } + + //! Equal-to operator with const C-string pointer + bool operator==(const Ch* rhs) const { return *this == GenericValue(StringRef(rhs)); } + +#if RAPIDJSON_HAS_STDSTRING + //! Equal-to operator with string object + /*! \note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING. + */ + bool operator==(const std::basic_string& rhs) const { return *this == GenericValue(StringRef(rhs)); } +#endif + + //! Equal-to operator with primitive types + /*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c double, \c true, \c false + */ + template RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr,internal::IsGenericValue >), (bool)) operator==(const T& rhs) const { return *this == GenericValue(rhs); } + +#ifndef __cpp_impl_three_way_comparison + //! Not-equal-to operator + /*! \return !(*this == rhs) + */ + template + bool operator!=(const GenericValue& rhs) const { return !(*this == rhs); } + + //! Not-equal-to operator with const C-string pointer + bool operator!=(const Ch* rhs) const { return !(*this == rhs); } + + //! Not-equal-to operator with arbitrary types + /*! \return !(*this == rhs) + */ + template RAPIDJSON_DISABLEIF_RETURN((internal::IsGenericValue), (bool)) operator!=(const T& rhs) const { return !(*this == rhs); } + + //! Equal-to operator with arbitrary types (symmetric version) + /*! \return (rhs == lhs) + */ + template friend RAPIDJSON_DISABLEIF_RETURN((internal::IsGenericValue), (bool)) operator==(const T& lhs, const GenericValue& rhs) { return rhs == lhs; } + + //! Not-Equal-to operator with arbitrary types (symmetric version) + /*! \return !(rhs == lhs) + */ + template friend RAPIDJSON_DISABLEIF_RETURN((internal::IsGenericValue), (bool)) operator!=(const T& lhs, const GenericValue& rhs) { return !(rhs == lhs); } + //@} +#endif + + //!@name Type + //@{ + + Type GetType() const { return static_cast(data_.f.flags & kTypeMask); } + bool IsNull() const { return data_.f.flags == kNullFlag; } + bool IsFalse() const { return data_.f.flags == kFalseFlag; } + bool IsTrue() const { return data_.f.flags == kTrueFlag; } + bool IsBool() const { return (data_.f.flags & kBoolFlag) != 0; } + bool IsObject() const { return data_.f.flags == kObjectFlag; } + bool IsArray() const { return data_.f.flags == kArrayFlag; } + bool IsNumber() const { return (data_.f.flags & kNumberFlag) != 0; } + bool IsInt() const { return (data_.f.flags & kIntFlag) != 0; } + bool IsUint() const { return (data_.f.flags & kUintFlag) != 0; } + bool IsInt64() const { return (data_.f.flags & kInt64Flag) != 0; } + bool IsUint64() const { return (data_.f.flags & kUint64Flag) != 0; } + bool IsDouble() const { return (data_.f.flags & kDoubleFlag) != 0; } + bool IsString() const { return (data_.f.flags & kStringFlag) != 0; } + + // Checks whether a number can be losslessly converted to a double. + bool IsLosslessDouble() const { + if (!IsNumber()) return false; + if (IsUint64()) { + uint64_t u = GetUint64(); + volatile double d = static_cast(u); + return (d >= 0.0) + && (d < static_cast((std::numeric_limits::max)())) + && (u == static_cast(d)); + } + if (IsInt64()) { + int64_t i = GetInt64(); + volatile double d = static_cast(i); + return (d >= static_cast((std::numeric_limits::min)())) + && (d < static_cast((std::numeric_limits::max)())) + && (i == static_cast(d)); + } + return true; // double, int, uint are always lossless + } + + // Checks whether a number is a float (possible lossy). + bool IsFloat() const { + if ((data_.f.flags & kDoubleFlag) == 0) + return false; + double d = GetDouble(); + return d >= -3.4028234e38 && d <= 3.4028234e38; + } + // Checks whether a number can be losslessly converted to a float. + bool IsLosslessFloat() const { + if (!IsNumber()) return false; + double a = GetDouble(); + if (a < static_cast(-(std::numeric_limits::max)()) + || a > static_cast((std::numeric_limits::max)())) + return false; + double b = static_cast(static_cast(a)); + return a >= b && a <= b; // Prevent -Wfloat-equal + } + + //@} + + //!@name Null + //@{ + + GenericValue& SetNull() { this->~GenericValue(); new (this) GenericValue(); return *this; } + + //@} + + //!@name Bool + //@{ + + bool GetBool() const { RAPIDJSON_ASSERT(IsBool()); return data_.f.flags == kTrueFlag; } + //!< Set boolean value + /*! \post IsBool() == true */ + GenericValue& SetBool(bool b) { this->~GenericValue(); new (this) GenericValue(b); return *this; } + + //@} + + //!@name Object + //@{ + + //! Set this value as an empty object. + /*! \post IsObject() == true */ + GenericValue& SetObject() { this->~GenericValue(); new (this) GenericValue(kObjectType); return *this; } + + //! Get the number of members in the object. + SizeType MemberCount() const { RAPIDJSON_ASSERT(IsObject()); return data_.o.size; } + + //! Get the capacity of object. + SizeType MemberCapacity() const { RAPIDJSON_ASSERT(IsObject()); return data_.o.capacity; } + + //! Check whether the object is empty. + bool ObjectEmpty() const { RAPIDJSON_ASSERT(IsObject()); return data_.o.size == 0; } + + //! Get a value from an object associated with the name. + /*! \pre IsObject() == true + \tparam T Either \c Ch or \c const \c Ch (template used for disambiguation with \ref operator[](SizeType)) + \note In version 0.1x, if the member is not found, this function returns a null value. This makes issue 7. + Since 0.2, if the name is not correct, it will assert. + If user is unsure whether a member exists, user should use HasMember() first. + A better approach is to use FindMember(). + \note Linear time complexity. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::NotExpr::Type, Ch> >),(GenericValue&)) operator[](T* name) { + GenericValue n(StringRef(name)); + return (*this)[n]; + } + template + RAPIDJSON_DISABLEIF_RETURN((internal::NotExpr::Type, Ch> >),(const GenericValue&)) operator[](T* name) const { return const_cast(*this)[name]; } + + //! Get a value from an object associated with the name. + /*! \pre IsObject() == true + \tparam SourceAllocator Allocator of the \c name value + + \note Compared to \ref operator[](T*), this version is faster because it does not need a StrLen(). + And it can also handle strings with embedded null characters. + + \note Linear time complexity. + */ + template + GenericValue& operator[](const GenericValue& name) { + MemberIterator member = FindMember(name); + if (member != MemberEnd()) + return member->value; + else { + RAPIDJSON_ASSERT(false); // see above note + +#if RAPIDJSON_HAS_CXX11 + // Use thread-local storage to prevent races between threads. + // Use static buffer and placement-new to prevent destruction, with + // alignas() to ensure proper alignment. + alignas(GenericValue) thread_local static char buffer[sizeof(GenericValue)]; + return *new (buffer) GenericValue(); +#elif defined(_MSC_VER) && _MSC_VER < 1900 + // There's no way to solve both thread locality and proper alignment + // simultaneously. + __declspec(thread) static char buffer[sizeof(GenericValue)]; + return *new (buffer) GenericValue(); +#elif defined(__GNUC__) || defined(__clang__) + // This will generate -Wexit-time-destructors in clang, but that's + // better than having under-alignment. + __thread static GenericValue buffer; + return buffer; +#else + // Don't know what compiler this is, so don't know how to ensure + // thread-locality. + static GenericValue buffer; + return buffer; +#endif + } + } + template + const GenericValue& operator[](const GenericValue& name) const { return const_cast(*this)[name]; } + +#if RAPIDJSON_HAS_STDSTRING + //! Get a value from an object associated with name (string object). + GenericValue& operator[](const std::basic_string& name) { return (*this)[GenericValue(StringRef(name))]; } + const GenericValue& operator[](const std::basic_string& name) const { return (*this)[GenericValue(StringRef(name))]; } +#endif + + //! Const member iterator + /*! \pre IsObject() == true */ + ConstMemberIterator MemberBegin() const { RAPIDJSON_ASSERT(IsObject()); return ConstMemberIterator(GetMembersPointer()); } + //! Const \em past-the-end member iterator + /*! \pre IsObject() == true */ + ConstMemberIterator MemberEnd() const { RAPIDJSON_ASSERT(IsObject()); return ConstMemberIterator(GetMembersPointer() + data_.o.size); } + //! Member iterator + /*! \pre IsObject() == true */ + MemberIterator MemberBegin() { RAPIDJSON_ASSERT(IsObject()); return MemberIterator(GetMembersPointer()); } + //! \em Past-the-end member iterator + /*! \pre IsObject() == true */ + MemberIterator MemberEnd() { RAPIDJSON_ASSERT(IsObject()); return MemberIterator(GetMembersPointer() + data_.o.size); } + + //! Request the object to have enough capacity to store members. + /*! \param newCapacity The capacity that the object at least need to have. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \note Linear time complexity. + */ + GenericValue& MemberReserve(SizeType newCapacity, Allocator &allocator) { + RAPIDJSON_ASSERT(IsObject()); + DoReserveMembers(newCapacity, allocator); + return *this; + } + + //! Check whether a member exists in the object. + /*! + \param name Member name to be searched. + \pre IsObject() == true + \return Whether a member with that name exists. + \note It is better to use FindMember() directly if you need the obtain the value as well. + \note Linear time complexity. + */ + bool HasMember(const Ch* name) const { return FindMember(name) != MemberEnd(); } + +#if RAPIDJSON_HAS_STDSTRING + //! Check whether a member exists in the object with string object. + /*! + \param name Member name to be searched. + \pre IsObject() == true + \return Whether a member with that name exists. + \note It is better to use FindMember() directly if you need the obtain the value as well. + \note Linear time complexity. + */ + bool HasMember(const std::basic_string& name) const { return FindMember(name) != MemberEnd(); } +#endif + + //! Check whether a member exists in the object with GenericValue name. + /*! + This version is faster because it does not need a StrLen(). It can also handle string with null character. + \param name Member name to be searched. + \pre IsObject() == true + \return Whether a member with that name exists. + \note It is better to use FindMember() directly if you need the obtain the value as well. + \note Linear time complexity. + */ + template + bool HasMember(const GenericValue& name) const { return FindMember(name) != MemberEnd(); } + + //! Find member by name. + /*! + \param name Member name to be searched. + \pre IsObject() == true + \return Iterator to member, if it exists. + Otherwise returns \ref MemberEnd(). + + \note Earlier versions of Rapidjson returned a \c NULL pointer, in case + the requested member doesn't exist. For consistency with e.g. + \c std::map, this has been changed to MemberEnd() now. + \note Linear time complexity. + */ + MemberIterator FindMember(const Ch* name) { + GenericValue n(StringRef(name)); + return FindMember(n); + } + + ConstMemberIterator FindMember(const Ch* name) const { return const_cast(*this).FindMember(name); } + + //! Find member by name. + /*! + This version is faster because it does not need a StrLen(). It can also handle string with null character. + \param name Member name to be searched. + \pre IsObject() == true + \return Iterator to member, if it exists. + Otherwise returns \ref MemberEnd(). + + \note Earlier versions of Rapidjson returned a \c NULL pointer, in case + the requested member doesn't exist. For consistency with e.g. + \c std::map, this has been changed to MemberEnd() now. + \note Linear time complexity. + */ + template + MemberIterator FindMember(const GenericValue& name) { + RAPIDJSON_ASSERT(IsObject()); + RAPIDJSON_ASSERT(name.IsString()); + return DoFindMember(name); + } + template ConstMemberIterator FindMember(const GenericValue& name) const { return const_cast(*this).FindMember(name); } + +#if RAPIDJSON_HAS_STDSTRING + //! Find member by string object name. + /*! + \param name Member name to be searched. + \pre IsObject() == true + \return Iterator to member, if it exists. + Otherwise returns \ref MemberEnd(). + */ + MemberIterator FindMember(const std::basic_string& name) { return FindMember(GenericValue(StringRef(name))); } + ConstMemberIterator FindMember(const std::basic_string& name) const { return FindMember(GenericValue(StringRef(name))); } +#endif + + //! Add a member (name-value pair) to the object. + /*! \param name A string value as name of member. + \param value Value of any type. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \note The ownership of \c name and \c value will be transferred to this object on success. + \pre IsObject() && name.IsString() + \post name.IsNull() && value.IsNull() + \note Amortized Constant time complexity. + */ + GenericValue& AddMember(GenericValue& name, GenericValue& value, Allocator& allocator) { + RAPIDJSON_ASSERT(IsObject()); + RAPIDJSON_ASSERT(name.IsString()); + DoAddMember(name, value, allocator); + return *this; + } + + //! Add a constant string value as member (name-value pair) to the object. + /*! \param name A string value as name of member. + \param value constant string reference as value of member. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \pre IsObject() + \note This overload is needed to avoid clashes with the generic primitive type AddMember(GenericValue&,T,Allocator&) overload below. + \note Amortized Constant time complexity. + */ + GenericValue& AddMember(GenericValue& name, StringRefType value, Allocator& allocator) { + GenericValue v(value); + return AddMember(name, v, allocator); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Add a string object as member (name-value pair) to the object. + /*! \param name A string value as name of member. + \param value constant string reference as value of member. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \pre IsObject() + \note This overload is needed to avoid clashes with the generic primitive type AddMember(GenericValue&,T,Allocator&) overload below. + \note Amortized Constant time complexity. + */ + GenericValue& AddMember(GenericValue& name, std::basic_string& value, Allocator& allocator) { + GenericValue v(value, allocator); + return AddMember(name, v, allocator); + } +#endif + + //! Add any primitive value as member (name-value pair) to the object. + /*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t + \param name A string value as name of member. + \param value Value of primitive type \c T as value of member + \param allocator Allocator for reallocating memory. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \pre IsObject() + + \note The source type \c T explicitly disallows all pointer types, + especially (\c const) \ref Ch*. This helps avoiding implicitly + referencing character strings with insufficient lifetime, use + \ref AddMember(StringRefType, GenericValue&, Allocator&) or \ref + AddMember(StringRefType, StringRefType, Allocator&). + All other pointer types would implicitly convert to \c bool, + use an explicit cast instead, if needed. + \note Amortized Constant time complexity. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (GenericValue&)) + AddMember(GenericValue& name, T value, Allocator& allocator) { + GenericValue v(value); + return AddMember(name, v, allocator); + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericValue& AddMember(GenericValue&& name, GenericValue&& value, Allocator& allocator) { + return AddMember(name, value, allocator); + } + GenericValue& AddMember(GenericValue&& name, GenericValue& value, Allocator& allocator) { + return AddMember(name, value, allocator); + } + GenericValue& AddMember(GenericValue& name, GenericValue&& value, Allocator& allocator) { + return AddMember(name, value, allocator); + } + GenericValue& AddMember(StringRefType name, GenericValue&& value, Allocator& allocator) { + GenericValue n(name); + return AddMember(n, value, allocator); + } +#endif // RAPIDJSON_HAS_CXX11_RVALUE_REFS + + + //! Add a member (name-value pair) to the object. + /*! \param name A constant string reference as name of member. + \param value Value of any type. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \note The ownership of \c value will be transferred to this object on success. + \pre IsObject() + \post value.IsNull() + \note Amortized Constant time complexity. + */ + GenericValue& AddMember(StringRefType name, GenericValue& value, Allocator& allocator) { + GenericValue n(name); + return AddMember(n, value, allocator); + } + + //! Add a constant string value as member (name-value pair) to the object. + /*! \param name A constant string reference as name of member. + \param value constant string reference as value of member. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \pre IsObject() + \note This overload is needed to avoid clashes with the generic primitive type AddMember(StringRefType,T,Allocator&) overload below. + \note Amortized Constant time complexity. + */ + GenericValue& AddMember(StringRefType name, StringRefType value, Allocator& allocator) { + GenericValue v(value); + return AddMember(name, v, allocator); + } + + //! Add any primitive value as member (name-value pair) to the object. + /*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t + \param name A constant string reference as name of member. + \param value Value of primitive type \c T as value of member + \param allocator Allocator for reallocating memory. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \pre IsObject() + + \note The source type \c T explicitly disallows all pointer types, + especially (\c const) \ref Ch*. This helps avoiding implicitly + referencing character strings with insufficient lifetime, use + \ref AddMember(StringRefType, GenericValue&, Allocator&) or \ref + AddMember(StringRefType, StringRefType, Allocator&). + All other pointer types would implicitly convert to \c bool, + use an explicit cast instead, if needed. + \note Amortized Constant time complexity. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (GenericValue&)) + AddMember(StringRefType name, T value, Allocator& allocator) { + GenericValue n(name); + return AddMember(n, value, allocator); + } + + //! Remove all members in the object. + /*! This function do not deallocate memory in the object, i.e. the capacity is unchanged. + \note Linear time complexity. + */ + void RemoveAllMembers() { + RAPIDJSON_ASSERT(IsObject()); + DoClearMembers(); + } + + //! Remove a member in object by its name. + /*! \param name Name of member to be removed. + \return Whether the member existed. + \note This function may reorder the object members. Use \ref + EraseMember(ConstMemberIterator) if you need to preserve the + relative order of the remaining members. + \note Linear time complexity. + */ + bool RemoveMember(const Ch* name) { + GenericValue n(StringRef(name)); + return RemoveMember(n); + } + +#if RAPIDJSON_HAS_STDSTRING + bool RemoveMember(const std::basic_string& name) { return RemoveMember(GenericValue(StringRef(name))); } +#endif + + template + bool RemoveMember(const GenericValue& name) { + MemberIterator m = FindMember(name); + if (m != MemberEnd()) { + RemoveMember(m); + return true; + } + else + return false; + } + + //! Remove a member in object by iterator. + /*! \param m member iterator (obtained by FindMember() or MemberBegin()). + \return the new iterator after removal. + \note This function may reorder the object members. Use \ref + EraseMember(ConstMemberIterator) if you need to preserve the + relative order of the remaining members. + \note Constant time complexity. + */ + MemberIterator RemoveMember(MemberIterator m) { + RAPIDJSON_ASSERT(IsObject()); + RAPIDJSON_ASSERT(data_.o.size > 0); + RAPIDJSON_ASSERT(GetMembersPointer() != 0); + RAPIDJSON_ASSERT(m >= MemberBegin() && m < MemberEnd()); + return DoRemoveMember(m); + } + + //! Remove a member from an object by iterator. + /*! \param pos iterator to the member to remove + \pre IsObject() == true && \ref MemberBegin() <= \c pos < \ref MemberEnd() + \return Iterator following the removed element. + If the iterator \c pos refers to the last element, the \ref MemberEnd() iterator is returned. + \note This function preserves the relative order of the remaining object + members. If you do not need this, use the more efficient \ref RemoveMember(MemberIterator). + \note Linear time complexity. + */ + MemberIterator EraseMember(ConstMemberIterator pos) { + return EraseMember(pos, pos +1); + } + + //! Remove members in the range [first, last) from an object. + /*! \param first iterator to the first member to remove + \param last iterator following the last member to remove + \pre IsObject() == true && \ref MemberBegin() <= \c first <= \c last <= \ref MemberEnd() + \return Iterator following the last removed element. + \note This function preserves the relative order of the remaining object + members. + \note Linear time complexity. + */ + MemberIterator EraseMember(ConstMemberIterator first, ConstMemberIterator last) { + RAPIDJSON_ASSERT(IsObject()); + RAPIDJSON_ASSERT(data_.o.size > 0); + RAPIDJSON_ASSERT(GetMembersPointer() != 0); + RAPIDJSON_ASSERT(first >= MemberBegin()); + RAPIDJSON_ASSERT(first <= last); + RAPIDJSON_ASSERT(last <= MemberEnd()); + return DoEraseMembers(first, last); + } + + //! Erase a member in object by its name. + /*! \param name Name of member to be removed. + \return Whether the member existed. + \note Linear time complexity. + */ + bool EraseMember(const Ch* name) { + GenericValue n(StringRef(name)); + return EraseMember(n); + } + +#if RAPIDJSON_HAS_STDSTRING + bool EraseMember(const std::basic_string& name) { return EraseMember(GenericValue(StringRef(name))); } +#endif + + template + bool EraseMember(const GenericValue& name) { + MemberIterator m = FindMember(name); + if (m != MemberEnd()) { + EraseMember(m); + return true; + } + else + return false; + } + + Object GetObject() { RAPIDJSON_ASSERT(IsObject()); return Object(*this); } + Object GetObj() { RAPIDJSON_ASSERT(IsObject()); return Object(*this); } + ConstObject GetObject() const { RAPIDJSON_ASSERT(IsObject()); return ConstObject(*this); } + ConstObject GetObj() const { RAPIDJSON_ASSERT(IsObject()); return ConstObject(*this); } + + //@} + + //!@name Array + //@{ + + //! Set this value as an empty array. + /*! \post IsArray == true */ + GenericValue& SetArray() { this->~GenericValue(); new (this) GenericValue(kArrayType); return *this; } + + //! Get the number of elements in array. + SizeType Size() const { RAPIDJSON_ASSERT(IsArray()); return data_.a.size; } + + //! Get the capacity of array. + SizeType Capacity() const { RAPIDJSON_ASSERT(IsArray()); return data_.a.capacity; } + + //! Check whether the array is empty. + bool Empty() const { RAPIDJSON_ASSERT(IsArray()); return data_.a.size == 0; } + + //! Remove all elements in the array. + /*! This function do not deallocate memory in the array, i.e. the capacity is unchanged. + \note Linear time complexity. + */ + void Clear() { + RAPIDJSON_ASSERT(IsArray()); + GenericValue* e = GetElementsPointer(); + for (GenericValue* v = e; v != e + data_.a.size; ++v) + v->~GenericValue(); + data_.a.size = 0; + } + + //! Get an element from array by index. + /*! \pre IsArray() == true + \param index Zero-based index of element. + \see operator[](T*) + */ + GenericValue& operator[](SizeType index) { + RAPIDJSON_ASSERT(IsArray()); + RAPIDJSON_ASSERT(index < data_.a.size); + return GetElementsPointer()[index]; + } + const GenericValue& operator[](SizeType index) const { return const_cast(*this)[index]; } + + //! Element iterator + /*! \pre IsArray() == true */ + ValueIterator Begin() { RAPIDJSON_ASSERT(IsArray()); return GetElementsPointer(); } + //! \em Past-the-end element iterator + /*! \pre IsArray() == true */ + ValueIterator End() { RAPIDJSON_ASSERT(IsArray()); return GetElementsPointer() + data_.a.size; } + //! Constant element iterator + /*! \pre IsArray() == true */ + ConstValueIterator Begin() const { return const_cast(*this).Begin(); } + //! Constant \em past-the-end element iterator + /*! \pre IsArray() == true */ + ConstValueIterator End() const { return const_cast(*this).End(); } + + //! Request the array to have enough capacity to store elements. + /*! \param newCapacity The capacity that the array at least need to have. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \note Linear time complexity. + */ + GenericValue& Reserve(SizeType newCapacity, Allocator &allocator) { + RAPIDJSON_ASSERT(IsArray()); + if (newCapacity > data_.a.capacity) { + SetElementsPointer(reinterpret_cast(allocator.Realloc(GetElementsPointer(), data_.a.capacity * sizeof(GenericValue), newCapacity * sizeof(GenericValue)))); + data_.a.capacity = newCapacity; + } + return *this; + } + + //! Append a GenericValue at the end of the array. + /*! \param value Value to be appended. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \pre IsArray() == true + \post value.IsNull() == true + \return The value itself for fluent API. + \note The ownership of \c value will be transferred to this array on success. + \note If the number of elements to be appended is known, calls Reserve() once first may be more efficient. + \note Amortized constant time complexity. + */ + GenericValue& PushBack(GenericValue& value, Allocator& allocator) { + RAPIDJSON_ASSERT(IsArray()); + if (data_.a.size >= data_.a.capacity) + Reserve(data_.a.capacity == 0 ? kDefaultArrayCapacity : (data_.a.capacity + (data_.a.capacity + 1) / 2), allocator); + GetElementsPointer()[data_.a.size++].RawAssign(value); + return *this; + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericValue& PushBack(GenericValue&& value, Allocator& allocator) { + return PushBack(value, allocator); + } +#endif // RAPIDJSON_HAS_CXX11_RVALUE_REFS + + //! Append a constant string reference at the end of the array. + /*! \param value Constant string reference to be appended. + \param allocator Allocator for reallocating memory. It must be the same one used previously. Commonly use GenericDocument::GetAllocator(). + \pre IsArray() == true + \return The value itself for fluent API. + \note If the number of elements to be appended is known, calls Reserve() once first may be more efficient. + \note Amortized constant time complexity. + \see GenericStringRef + */ + GenericValue& PushBack(StringRefType value, Allocator& allocator) { + return (*this).template PushBack(value, allocator); + } + + //! Append a primitive value at the end of the array. + /*! \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t + \param value Value of primitive type T to be appended. + \param allocator Allocator for reallocating memory. It must be the same one as used before. Commonly use GenericDocument::GetAllocator(). + \pre IsArray() == true + \return The value itself for fluent API. + \note If the number of elements to be appended is known, calls Reserve() once first may be more efficient. + + \note The source type \c T explicitly disallows all pointer types, + especially (\c const) \ref Ch*. This helps avoiding implicitly + referencing character strings with insufficient lifetime, use + \ref PushBack(GenericValue&, Allocator&) or \ref + PushBack(StringRefType, Allocator&). + All other pointer types would implicitly convert to \c bool, + use an explicit cast instead, if needed. + \note Amortized constant time complexity. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (GenericValue&)) + PushBack(T value, Allocator& allocator) { + GenericValue v(value); + return PushBack(v, allocator); + } + + //! Remove the last element in the array. + /*! + \note Constant time complexity. + */ + GenericValue& PopBack() { + RAPIDJSON_ASSERT(IsArray()); + RAPIDJSON_ASSERT(!Empty()); + GetElementsPointer()[--data_.a.size].~GenericValue(); + return *this; + } + + //! Remove an element of array by iterator. + /*! + \param pos iterator to the element to remove + \pre IsArray() == true && \ref Begin() <= \c pos < \ref End() + \return Iterator following the removed element. If the iterator pos refers to the last element, the End() iterator is returned. + \note Linear time complexity. + */ + ValueIterator Erase(ConstValueIterator pos) { + return Erase(pos, pos + 1); + } + + //! Remove elements in the range [first, last) of the array. + /*! + \param first iterator to the first element to remove + \param last iterator following the last element to remove + \pre IsArray() == true && \ref Begin() <= \c first <= \c last <= \ref End() + \return Iterator following the last removed element. + \note Linear time complexity. + */ + ValueIterator Erase(ConstValueIterator first, ConstValueIterator last) { + RAPIDJSON_ASSERT(IsArray()); + RAPIDJSON_ASSERT(data_.a.size > 0); + RAPIDJSON_ASSERT(GetElementsPointer() != 0); + RAPIDJSON_ASSERT(first >= Begin()); + RAPIDJSON_ASSERT(first <= last); + RAPIDJSON_ASSERT(last <= End()); + ValueIterator pos = Begin() + (first - Begin()); + for (ValueIterator itr = pos; itr != last; ++itr) + itr->~GenericValue(); + std::memmove(static_cast(pos), last, static_cast(End() - last) * sizeof(GenericValue)); + data_.a.size -= static_cast(last - first); + return pos; + } + + Array GetArray() { RAPIDJSON_ASSERT(IsArray()); return Array(*this); } + ConstArray GetArray() const { RAPIDJSON_ASSERT(IsArray()); return ConstArray(*this); } + + //@} + + //!@name Number + //@{ + + int GetInt() const { RAPIDJSON_ASSERT(data_.f.flags & kIntFlag); return data_.n.i.i; } + unsigned GetUint() const { RAPIDJSON_ASSERT(data_.f.flags & kUintFlag); return data_.n.u.u; } + int64_t GetInt64() const { RAPIDJSON_ASSERT(data_.f.flags & kInt64Flag); return data_.n.i64; } + uint64_t GetUint64() const { RAPIDJSON_ASSERT(data_.f.flags & kUint64Flag); return data_.n.u64; } + + //! Get the value as double type. + /*! \note If the value is 64-bit integer type, it may lose precision. Use \c IsLosslessDouble() to check whether the converison is lossless. + */ + double GetDouble() const { + RAPIDJSON_ASSERT(IsNumber()); + if ((data_.f.flags & kDoubleFlag) != 0) return data_.n.d; // exact type, no conversion. + if ((data_.f.flags & kIntFlag) != 0) return data_.n.i.i; // int -> double + if ((data_.f.flags & kUintFlag) != 0) return data_.n.u.u; // unsigned -> double + if ((data_.f.flags & kInt64Flag) != 0) return static_cast(data_.n.i64); // int64_t -> double (may lose precision) + RAPIDJSON_ASSERT((data_.f.flags & kUint64Flag) != 0); return static_cast(data_.n.u64); // uint64_t -> double (may lose precision) + } + + //! Get the value as float type. + /*! \note If the value is 64-bit integer type, it may lose precision. Use \c IsLosslessFloat() to check whether the converison is lossless. + */ + float GetFloat() const { + return static_cast(GetDouble()); + } + + GenericValue& SetInt(int i) { this->~GenericValue(); new (this) GenericValue(i); return *this; } + GenericValue& SetUint(unsigned u) { this->~GenericValue(); new (this) GenericValue(u); return *this; } + GenericValue& SetInt64(int64_t i64) { this->~GenericValue(); new (this) GenericValue(i64); return *this; } + GenericValue& SetUint64(uint64_t u64) { this->~GenericValue(); new (this) GenericValue(u64); return *this; } + GenericValue& SetDouble(double d) { this->~GenericValue(); new (this) GenericValue(d); return *this; } + GenericValue& SetFloat(float f) { this->~GenericValue(); new (this) GenericValue(static_cast(f)); return *this; } + + //@} + + //!@name String + //@{ + + const Ch* GetString() const { RAPIDJSON_ASSERT(IsString()); return DataString(data_); } + + //! Get the length of string. + /*! Since rapidjson permits "\\u0000" in the json string, strlen(v.GetString()) may not equal to v.GetStringLength(). + */ + SizeType GetStringLength() const { RAPIDJSON_ASSERT(IsString()); return DataStringLength(data_); } + + //! Set this value as a string without copying source string. + /*! This version has better performance with supplied length, and also support string containing null character. + \param s source string pointer. + \param length The length of source string, excluding the trailing null terminator. + \return The value itself for fluent API. + \post IsString() == true && GetString() == s && GetStringLength() == length + \see SetString(StringRefType) + */ + GenericValue& SetString(const Ch* s, SizeType length) { return SetString(StringRef(s, length)); } + + //! Set this value as a string without copying source string. + /*! \param s source string reference + \return The value itself for fluent API. + \post IsString() == true && GetString() == s && GetStringLength() == s.length + */ + GenericValue& SetString(StringRefType s) { this->~GenericValue(); SetStringRaw(s); return *this; } + + //! Set this value as a string by copying from source string. + /*! This version has better performance with supplied length, and also support string containing null character. + \param s source string. + \param length The length of source string, excluding the trailing null terminator. + \param allocator Allocator for allocating copied buffer. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \post IsString() == true && GetString() != s && strcmp(GetString(),s) == 0 && GetStringLength() == length + */ + GenericValue& SetString(const Ch* s, SizeType length, Allocator& allocator) { return SetString(StringRef(s, length), allocator); } + + //! Set this value as a string by copying from source string. + /*! \param s source string. + \param allocator Allocator for allocating copied buffer. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \post IsString() == true && GetString() != s && strcmp(GetString(),s) == 0 && GetStringLength() == length + */ + GenericValue& SetString(const Ch* s, Allocator& allocator) { return SetString(StringRef(s), allocator); } + + //! Set this value as a string by copying from source string. + /*! \param s source string reference + \param allocator Allocator for allocating copied buffer. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \post IsString() == true && GetString() != s.s && strcmp(GetString(),s) == 0 && GetStringLength() == length + */ + GenericValue& SetString(StringRefType s, Allocator& allocator) { this->~GenericValue(); SetStringRaw(s, allocator); return *this; } + +#if RAPIDJSON_HAS_STDSTRING + //! Set this value as a string by copying from source string. + /*! \param s source string. + \param allocator Allocator for allocating copied buffer. Commonly use GenericDocument::GetAllocator(). + \return The value itself for fluent API. + \post IsString() == true && GetString() != s.data() && strcmp(GetString(),s.data() == 0 && GetStringLength() == s.size() + \note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING. + */ + GenericValue& SetString(const std::basic_string& s, Allocator& allocator) { return SetString(StringRef(s), allocator); } +#endif + + //@} + + //!@name Array + //@{ + + //! Templated version for checking whether this value is type T. + /*! + \tparam T Either \c bool, \c int, \c unsigned, \c int64_t, \c uint64_t, \c double, \c float, \c const \c char*, \c std::basic_string + */ + template + bool Is() const { return internal::TypeHelper::Is(*this); } + + template + T Get() const { return internal::TypeHelper::Get(*this); } + + template + T Get() { return internal::TypeHelper::Get(*this); } + + template + ValueType& Set(const T& data) { return internal::TypeHelper::Set(*this, data); } + + template + ValueType& Set(const T& data, AllocatorType& allocator) { return internal::TypeHelper::Set(*this, data, allocator); } + + //@} + + //! Generate events of this value to a Handler. + /*! This function adopts the GoF visitor pattern. + Typical usage is to output this JSON value as JSON text via Writer, which is a Handler. + It can also be used to deep clone this value via GenericDocument, which is also a Handler. + \tparam Handler type of handler. + \param handler An object implementing concept Handler. + */ + template + bool Accept(Handler& handler) const { + switch(GetType()) { + case kNullType: return handler.Null(); + case kFalseType: return handler.Bool(false); + case kTrueType: return handler.Bool(true); + + case kObjectType: + if (RAPIDJSON_UNLIKELY(!handler.StartObject())) + return false; + for (ConstMemberIterator m = MemberBegin(); m != MemberEnd(); ++m) { + RAPIDJSON_ASSERT(m->name.IsString()); // User may change the type of name by MemberIterator. + if (RAPIDJSON_UNLIKELY(!handler.Key(m->name.GetString(), m->name.GetStringLength(), (m->name.data_.f.flags & kCopyFlag) != 0))) + return false; + if (RAPIDJSON_UNLIKELY(!m->value.Accept(handler))) + return false; + } + return handler.EndObject(data_.o.size); + + case kArrayType: + if (RAPIDJSON_UNLIKELY(!handler.StartArray())) + return false; + for (ConstValueIterator v = Begin(); v != End(); ++v) + if (RAPIDJSON_UNLIKELY(!v->Accept(handler))) + return false; + return handler.EndArray(data_.a.size); + + case kStringType: + return handler.String(GetString(), GetStringLength(), (data_.f.flags & kCopyFlag) != 0); + + default: + RAPIDJSON_ASSERT(GetType() == kNumberType); + if (IsDouble()) return handler.Double(data_.n.d); + else if (IsInt()) return handler.Int(data_.n.i.i); + else if (IsUint()) return handler.Uint(data_.n.u.u); + else if (IsInt64()) return handler.Int64(data_.n.i64); + else return handler.Uint64(data_.n.u64); + } + } + +private: + template friend class GenericValue; + template friend class GenericDocument; + + enum { + kBoolFlag = 0x0008, + kNumberFlag = 0x0010, + kIntFlag = 0x0020, + kUintFlag = 0x0040, + kInt64Flag = 0x0080, + kUint64Flag = 0x0100, + kDoubleFlag = 0x0200, + kStringFlag = 0x0400, + kCopyFlag = 0x0800, + kInlineStrFlag = 0x1000, + + // Initial flags of different types. + kNullFlag = kNullType, + // These casts are added to suppress the warning on MSVC about bitwise operations between enums of different types. + kTrueFlag = static_cast(kTrueType) | static_cast(kBoolFlag), + kFalseFlag = static_cast(kFalseType) | static_cast(kBoolFlag), + kNumberIntFlag = static_cast(kNumberType) | static_cast(kNumberFlag | kIntFlag | kInt64Flag), + kNumberUintFlag = static_cast(kNumberType) | static_cast(kNumberFlag | kUintFlag | kUint64Flag | kInt64Flag), + kNumberInt64Flag = static_cast(kNumberType) | static_cast(kNumberFlag | kInt64Flag), + kNumberUint64Flag = static_cast(kNumberType) | static_cast(kNumberFlag | kUint64Flag), + kNumberDoubleFlag = static_cast(kNumberType) | static_cast(kNumberFlag | kDoubleFlag), + kNumberAnyFlag = static_cast(kNumberType) | static_cast(kNumberFlag | kIntFlag | kInt64Flag | kUintFlag | kUint64Flag | kDoubleFlag), + kConstStringFlag = static_cast(kStringType) | static_cast(kStringFlag), + kCopyStringFlag = static_cast(kStringType) | static_cast(kStringFlag | kCopyFlag), + kShortStringFlag = static_cast(kStringType) | static_cast(kStringFlag | kCopyFlag | kInlineStrFlag), + kObjectFlag = kObjectType, + kArrayFlag = kArrayType, + + kTypeMask = 0x07 + }; + + static const SizeType kDefaultArrayCapacity = RAPIDJSON_VALUE_DEFAULT_ARRAY_CAPACITY; + static const SizeType kDefaultObjectCapacity = RAPIDJSON_VALUE_DEFAULT_OBJECT_CAPACITY; + + struct Flag { +#if RAPIDJSON_48BITPOINTER_OPTIMIZATION + char payload[sizeof(SizeType) * 2 + 6]; // 2 x SizeType + lower 48-bit pointer +#elif RAPIDJSON_64BIT + char payload[sizeof(SizeType) * 2 + sizeof(void*) + 6]; // 6 padding bytes +#else + char payload[sizeof(SizeType) * 2 + sizeof(void*) + 2]; // 2 padding bytes +#endif + uint16_t flags; + }; + + struct String { + SizeType length; + SizeType hashcode; //!< reserved + const Ch* str; + }; // 12 bytes in 32-bit mode, 16 bytes in 64-bit mode + + // implementation detail: ShortString can represent zero-terminated strings up to MaxSize chars + // (excluding the terminating zero) and store a value to determine the length of the contained + // string in the last character str[LenPos] by storing "MaxSize - length" there. If the string + // to store has the maximal length of MaxSize then str[LenPos] will be 0 and therefore act as + // the string terminator as well. For getting the string length back from that value just use + // "MaxSize - str[LenPos]". + // This allows to store 13-chars strings in 32-bit mode, 21-chars strings in 64-bit mode, + // 13-chars strings for RAPIDJSON_48BITPOINTER_OPTIMIZATION=1 inline (for `UTF8`-encoded strings). + struct ShortString { + enum { MaxChars = sizeof(static_cast(0)->payload) / sizeof(Ch), MaxSize = MaxChars - 1, LenPos = MaxSize }; + Ch str[MaxChars]; + + inline static bool Usable(SizeType len) { return (MaxSize >= len); } + inline void SetLength(SizeType len) { str[LenPos] = static_cast(MaxSize - len); } + inline SizeType GetLength() const { return static_cast(MaxSize - str[LenPos]); } + }; // at most as many bytes as "String" above => 12 bytes in 32-bit mode, 16 bytes in 64-bit mode + + // By using proper binary layout, retrieval of different integer types do not need conversions. + union Number { +#if RAPIDJSON_ENDIAN == RAPIDJSON_LITTLEENDIAN + struct I { + int i; + char padding[4]; + }i; + struct U { + unsigned u; + char padding2[4]; + }u; +#else + struct I { + char padding[4]; + int i; + }i; + struct U { + char padding2[4]; + unsigned u; + }u; +#endif + int64_t i64; + uint64_t u64; + double d; + }; // 8 bytes + + struct ObjectData { + SizeType size; + SizeType capacity; + Member* members; + }; // 12 bytes in 32-bit mode, 16 bytes in 64-bit mode + + struct ArrayData { + SizeType size; + SizeType capacity; + GenericValue* elements; + }; // 12 bytes in 32-bit mode, 16 bytes in 64-bit mode + + union Data { + String s; + ShortString ss; + Number n; + ObjectData o; + ArrayData a; + Flag f; + }; // 16 bytes in 32-bit mode, 24 bytes in 64-bit mode, 16 bytes in 64-bit with RAPIDJSON_48BITPOINTER_OPTIMIZATION + + static RAPIDJSON_FORCEINLINE const Ch* DataString(const Data& data) { + return (data.f.flags & kInlineStrFlag) ? data.ss.str : RAPIDJSON_GETPOINTER(Ch, data.s.str); + } + static RAPIDJSON_FORCEINLINE SizeType DataStringLength(const Data& data) { + return (data.f.flags & kInlineStrFlag) ? data.ss.GetLength() : data.s.length; + } + + RAPIDJSON_FORCEINLINE const Ch* GetStringPointer() const { return RAPIDJSON_GETPOINTER(Ch, data_.s.str); } + RAPIDJSON_FORCEINLINE const Ch* SetStringPointer(const Ch* str) { return RAPIDJSON_SETPOINTER(Ch, data_.s.str, str); } + RAPIDJSON_FORCEINLINE GenericValue* GetElementsPointer() const { return RAPIDJSON_GETPOINTER(GenericValue, data_.a.elements); } + RAPIDJSON_FORCEINLINE GenericValue* SetElementsPointer(GenericValue* elements) { return RAPIDJSON_SETPOINTER(GenericValue, data_.a.elements, elements); } + RAPIDJSON_FORCEINLINE Member* GetMembersPointer() const { return RAPIDJSON_GETPOINTER(Member, data_.o.members); } + RAPIDJSON_FORCEINLINE Member* SetMembersPointer(Member* members) { return RAPIDJSON_SETPOINTER(Member, data_.o.members, members); } + +#if RAPIDJSON_USE_MEMBERSMAP + + struct MapTraits { + struct Less { + bool operator()(const Data& s1, const Data& s2) const { + SizeType n1 = DataStringLength(s1), n2 = DataStringLength(s2); + int cmp = std::memcmp(DataString(s1), DataString(s2), sizeof(Ch) * (n1 < n2 ? n1 : n2)); + return cmp < 0 || (cmp == 0 && n1 < n2); + } + }; + typedef std::pair Pair; + typedef std::multimap > Map; + typedef typename Map::iterator Iterator; + }; + typedef typename MapTraits::Map Map; + typedef typename MapTraits::Less MapLess; + typedef typename MapTraits::Pair MapPair; + typedef typename MapTraits::Iterator MapIterator; + + // + // Layout of the members' map/array, re(al)located according to the needed capacity: + // + // {Map*}<>{capacity}<>{Member[capacity]}<>{MapIterator[capacity]} + // + // (where <> stands for the RAPIDJSON_ALIGN-ment, if needed) + // + + static RAPIDJSON_FORCEINLINE size_t GetMapLayoutSize(SizeType capacity) { + return RAPIDJSON_ALIGN(sizeof(Map*)) + + RAPIDJSON_ALIGN(sizeof(SizeType)) + + RAPIDJSON_ALIGN(capacity * sizeof(Member)) + + capacity * sizeof(MapIterator); + } + + static RAPIDJSON_FORCEINLINE SizeType &GetMapCapacity(Map* &map) { + return *reinterpret_cast(reinterpret_cast(&map) + + RAPIDJSON_ALIGN(sizeof(Map*))); + } + + static RAPIDJSON_FORCEINLINE Member* GetMapMembers(Map* &map) { + return reinterpret_cast(reinterpret_cast(&map) + + RAPIDJSON_ALIGN(sizeof(Map*)) + + RAPIDJSON_ALIGN(sizeof(SizeType))); + } + + static RAPIDJSON_FORCEINLINE MapIterator* GetMapIterators(Map* &map) { + return reinterpret_cast(reinterpret_cast(&map) + + RAPIDJSON_ALIGN(sizeof(Map*)) + + RAPIDJSON_ALIGN(sizeof(SizeType)) + + RAPIDJSON_ALIGN(GetMapCapacity(map) * sizeof(Member))); + } + + static RAPIDJSON_FORCEINLINE Map* &GetMap(Member* members) { + RAPIDJSON_ASSERT(members != 0); + return *reinterpret_cast(reinterpret_cast(members) - + RAPIDJSON_ALIGN(sizeof(SizeType)) - + RAPIDJSON_ALIGN(sizeof(Map*))); + } + + // Some compilers' debug mechanisms want all iterators to be destroyed, for their accounting.. + RAPIDJSON_FORCEINLINE MapIterator DropMapIterator(MapIterator& rhs) { +#if RAPIDJSON_HAS_CXX11 + MapIterator ret = std::move(rhs); +#else + MapIterator ret = rhs; +#endif + rhs.~MapIterator(); + return ret; + } + + Map* &DoReallocMap(Map** oldMap, SizeType newCapacity, Allocator& allocator) { + Map **newMap = static_cast(allocator.Malloc(GetMapLayoutSize(newCapacity))); + GetMapCapacity(*newMap) = newCapacity; + if (!oldMap) { + *newMap = new (allocator.Malloc(sizeof(Map))) Map(MapLess(), allocator); + } + else { + *newMap = *oldMap; + size_t count = (*oldMap)->size(); + std::memcpy(static_cast(GetMapMembers(*newMap)), + static_cast(GetMapMembers(*oldMap)), + count * sizeof(Member)); + MapIterator *oldIt = GetMapIterators(*oldMap), + *newIt = GetMapIterators(*newMap); + while (count--) { + new (&newIt[count]) MapIterator(DropMapIterator(oldIt[count])); + } + Allocator::Free(oldMap); + } + return *newMap; + } + + RAPIDJSON_FORCEINLINE Member* DoAllocMembers(SizeType capacity, Allocator& allocator) { + return GetMapMembers(DoReallocMap(0, capacity, allocator)); + } + + void DoReserveMembers(SizeType newCapacity, Allocator& allocator) { + ObjectData& o = data_.o; + if (newCapacity > o.capacity) { + Member* oldMembers = GetMembersPointer(); + Map **oldMap = oldMembers ? &GetMap(oldMembers) : 0, + *&newMap = DoReallocMap(oldMap, newCapacity, allocator); + RAPIDJSON_SETPOINTER(Member, o.members, GetMapMembers(newMap)); + o.capacity = newCapacity; + } + } + + template + MemberIterator DoFindMember(const GenericValue& name) { + if (Member* members = GetMembersPointer()) { + Map* &map = GetMap(members); + MapIterator mit = map->find(reinterpret_cast(name.data_)); + if (mit != map->end()) { + return MemberIterator(&members[mit->second]); + } + } + return MemberEnd(); + } + + void DoClearMembers() { + if (Member* members = GetMembersPointer()) { + Map* &map = GetMap(members); + MapIterator* mit = GetMapIterators(map); + for (SizeType i = 0; i < data_.o.size; i++) { + map->erase(DropMapIterator(mit[i])); + members[i].~Member(); + } + data_.o.size = 0; + } + } + + void DoFreeMembers() { + if (Member* members = GetMembersPointer()) { + GetMap(members)->~Map(); + for (SizeType i = 0; i < data_.o.size; i++) { + members[i].~Member(); + } + if (Allocator::kNeedFree) { // Shortcut by Allocator's trait + Map** map = &GetMap(members); + Allocator::Free(*map); + Allocator::Free(map); + } + } + } + +#else // !RAPIDJSON_USE_MEMBERSMAP + + RAPIDJSON_FORCEINLINE Member* DoAllocMembers(SizeType capacity, Allocator& allocator) { + return Malloc(allocator, capacity); + } + + void DoReserveMembers(SizeType newCapacity, Allocator& allocator) { + ObjectData& o = data_.o; + if (newCapacity > o.capacity) { + Member* newMembers = Realloc(allocator, GetMembersPointer(), o.capacity, newCapacity); + RAPIDJSON_SETPOINTER(Member, o.members, newMembers); + o.capacity = newCapacity; + } + } + + template + MemberIterator DoFindMember(const GenericValue& name) { + MemberIterator member = MemberBegin(); + for ( ; member != MemberEnd(); ++member) + if (name.StringEqual(member->name)) + break; + return member; + } + + void DoClearMembers() { + for (MemberIterator m = MemberBegin(); m != MemberEnd(); ++m) + m->~Member(); + data_.o.size = 0; + } + + void DoFreeMembers() { + for (MemberIterator m = MemberBegin(); m != MemberEnd(); ++m) + m->~Member(); + Allocator::Free(GetMembersPointer()); + } + +#endif // !RAPIDJSON_USE_MEMBERSMAP + + void DoAddMember(GenericValue& name, GenericValue& value, Allocator& allocator) { + ObjectData& o = data_.o; + if (o.size >= o.capacity) + DoReserveMembers(o.capacity ? (o.capacity + (o.capacity + 1) / 2) : kDefaultObjectCapacity, allocator); + Member* members = GetMembersPointer(); + Member* m = members + o.size; + m->name.RawAssign(name); + m->value.RawAssign(value); +#if RAPIDJSON_USE_MEMBERSMAP + Map* &map = GetMap(members); + MapIterator* mit = GetMapIterators(map); + new (&mit[o.size]) MapIterator(map->insert(MapPair(m->name.data_, o.size))); +#endif + ++o.size; + } + + MemberIterator DoRemoveMember(MemberIterator m) { + ObjectData& o = data_.o; + Member* members = GetMembersPointer(); +#if RAPIDJSON_USE_MEMBERSMAP + Map* &map = GetMap(members); + MapIterator* mit = GetMapIterators(map); + SizeType mpos = static_cast(&*m - members); + map->erase(DropMapIterator(mit[mpos])); +#endif + MemberIterator last(members + (o.size - 1)); + if (o.size > 1 && m != last) { +#if RAPIDJSON_USE_MEMBERSMAP + new (&mit[mpos]) MapIterator(DropMapIterator(mit[&*last - members])); + mit[mpos]->second = mpos; +#endif + *m = *last; // Move the last one to this place + } + else { + m->~Member(); // Only one left, just destroy + } + --o.size; + return m; + } + + MemberIterator DoEraseMembers(ConstMemberIterator first, ConstMemberIterator last) { + ObjectData& o = data_.o; + MemberIterator beg = MemberBegin(), + pos = beg + (first - beg), + end = MemberEnd(); +#if RAPIDJSON_USE_MEMBERSMAP + Map* &map = GetMap(GetMembersPointer()); + MapIterator* mit = GetMapIterators(map); +#endif + for (MemberIterator itr = pos; itr != last; ++itr) { +#if RAPIDJSON_USE_MEMBERSMAP + map->erase(DropMapIterator(mit[itr - beg])); +#endif + itr->~Member(); + } +#if RAPIDJSON_USE_MEMBERSMAP + if (first != last) { + // Move remaining members/iterators + MemberIterator next = pos + (last - first); + for (MemberIterator itr = pos; next != end; ++itr, ++next) { + std::memcpy(static_cast(&*itr), &*next, sizeof(Member)); + SizeType mpos = static_cast(itr - beg); + new (&mit[mpos]) MapIterator(DropMapIterator(mit[next - beg])); + mit[mpos]->second = mpos; + } + } +#else + std::memmove(static_cast(&*pos), &*last, + static_cast(end - last) * sizeof(Member)); +#endif + o.size -= static_cast(last - first); + return pos; + } + + template + void DoCopyMembers(const GenericValue& rhs, Allocator& allocator, bool copyConstStrings) { + RAPIDJSON_ASSERT(rhs.GetType() == kObjectType); + + data_.f.flags = kObjectFlag; + SizeType count = rhs.data_.o.size; + Member* lm = DoAllocMembers(count, allocator); + const typename GenericValue::Member* rm = rhs.GetMembersPointer(); +#if RAPIDJSON_USE_MEMBERSMAP + Map* &map = GetMap(lm); + MapIterator* mit = GetMapIterators(map); +#endif + for (SizeType i = 0; i < count; i++) { + new (&lm[i].name) GenericValue(rm[i].name, allocator, copyConstStrings); + new (&lm[i].value) GenericValue(rm[i].value, allocator, copyConstStrings); +#if RAPIDJSON_USE_MEMBERSMAP + new (&mit[i]) MapIterator(map->insert(MapPair(lm[i].name.data_, i))); +#endif + } + data_.o.size = data_.o.capacity = count; + SetMembersPointer(lm); + } + + // Initialize this value as array with initial data, without calling destructor. + void SetArrayRaw(GenericValue* values, SizeType count, Allocator& allocator) { + data_.f.flags = kArrayFlag; + if (count) { + GenericValue* e = static_cast(allocator.Malloc(count * sizeof(GenericValue))); + SetElementsPointer(e); + std::memcpy(static_cast(e), values, count * sizeof(GenericValue)); + } + else + SetElementsPointer(0); + data_.a.size = data_.a.capacity = count; + } + + //! Initialize this value as object with initial data, without calling destructor. + void SetObjectRaw(Member* members, SizeType count, Allocator& allocator) { + data_.f.flags = kObjectFlag; + if (count) { + Member* m = DoAllocMembers(count, allocator); + SetMembersPointer(m); + std::memcpy(static_cast(m), members, count * sizeof(Member)); +#if RAPIDJSON_USE_MEMBERSMAP + Map* &map = GetMap(m); + MapIterator* mit = GetMapIterators(map); + for (SizeType i = 0; i < count; i++) { + new (&mit[i]) MapIterator(map->insert(MapPair(m[i].name.data_, i))); + } +#endif + } + else + SetMembersPointer(0); + data_.o.size = data_.o.capacity = count; + } + + //! Initialize this value as constant string, without calling destructor. + void SetStringRaw(StringRefType s) RAPIDJSON_NOEXCEPT { + data_.f.flags = kConstStringFlag; + SetStringPointer(s); + data_.s.length = s.length; + } + + //! Initialize this value as copy string with initial data, without calling destructor. + void SetStringRaw(StringRefType s, Allocator& allocator) { + Ch* str = 0; + if (ShortString::Usable(s.length)) { + data_.f.flags = kShortStringFlag; + data_.ss.SetLength(s.length); + str = data_.ss.str; + std::memmove(str, s, s.length * sizeof(Ch)); + } else { + data_.f.flags = kCopyStringFlag; + data_.s.length = s.length; + str = static_cast(allocator.Malloc((s.length + 1) * sizeof(Ch))); + SetStringPointer(str); + std::memcpy(str, s, s.length * sizeof(Ch)); + } + str[s.length] = '\0'; + } + + //! Assignment without calling destructor + void RawAssign(GenericValue& rhs) RAPIDJSON_NOEXCEPT { + data_ = rhs.data_; + // data_.f.flags = rhs.data_.f.flags; + rhs.data_.f.flags = kNullFlag; + } + + template + bool StringEqual(const GenericValue& rhs) const { + RAPIDJSON_ASSERT(IsString()); + RAPIDJSON_ASSERT(rhs.IsString()); + + const SizeType len1 = GetStringLength(); + const SizeType len2 = rhs.GetStringLength(); + if(len1 != len2) { return false; } + + const Ch* const str1 = GetString(); + const Ch* const str2 = rhs.GetString(); + if(str1 == str2) { return true; } // fast path for constant string + + return (std::memcmp(str1, str2, sizeof(Ch) * len1) == 0); + } + + Data data_; +}; + +//! GenericValue with UTF8 encoding +typedef GenericValue > Value; + +/////////////////////////////////////////////////////////////////////////////// +// GenericDocument + +//! A document for parsing JSON text as DOM. +/*! + \note implements Handler concept + \tparam Encoding Encoding for both parsing and string storage. + \tparam Allocator Allocator for allocating memory for the DOM + \tparam StackAllocator Allocator for allocating memory for stack during parsing. + \warning Although GenericDocument inherits from GenericValue, the API does \b not provide any virtual functions, especially no virtual destructor. To avoid memory leaks, do not \c delete a GenericDocument object via a pointer to a GenericValue. +*/ +template +class GenericDocument : public GenericValue { +public: + typedef typename Encoding::Ch Ch; //!< Character type derived from Encoding. + typedef GenericValue ValueType; //!< Value type of the document. + typedef Allocator AllocatorType; //!< Allocator type from template parameter. + typedef StackAllocator StackAllocatorType; //!< StackAllocator type from template parameter. + + //! Constructor + /*! Creates an empty document of specified type. + \param type Mandatory type of object to create. + \param allocator Optional allocator for allocating memory. + \param stackCapacity Optional initial capacity of stack in bytes. + \param stackAllocator Optional allocator for allocating memory for stack. + */ + explicit GenericDocument(Type type, Allocator* allocator = 0, size_t stackCapacity = kDefaultStackCapacity, StackAllocator* stackAllocator = 0) : + GenericValue(type), allocator_(allocator), ownAllocator_(0), stack_(stackAllocator, stackCapacity), parseResult_() + { + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + } + + //! Constructor + /*! Creates an empty document which type is Null. + \param allocator Optional allocator for allocating memory. + \param stackCapacity Optional initial capacity of stack in bytes. + \param stackAllocator Optional allocator for allocating memory for stack. + */ + GenericDocument(Allocator* allocator = 0, size_t stackCapacity = kDefaultStackCapacity, StackAllocator* stackAllocator = 0) : + allocator_(allocator), ownAllocator_(0), stack_(stackAllocator, stackCapacity), parseResult_() + { + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move constructor in C++11 + GenericDocument(GenericDocument&& rhs) RAPIDJSON_NOEXCEPT + : ValueType(std::forward(rhs)), // explicit cast to avoid prohibited move from Document + allocator_(rhs.allocator_), + ownAllocator_(rhs.ownAllocator_), + stack_(std::move(rhs.stack_)), + parseResult_(rhs.parseResult_) + { + rhs.allocator_ = 0; + rhs.ownAllocator_ = 0; + rhs.parseResult_ = ParseResult(); + } +#endif + + ~GenericDocument() { + // Clear the ::ValueType before ownAllocator is destroyed, ~ValueType() + // runs last and may access its elements or members which would be freed + // with an allocator like MemoryPoolAllocator (CrtAllocator does not + // free its data when destroyed, but MemoryPoolAllocator does). + if (ownAllocator_) { + ValueType::SetNull(); + } + Destroy(); + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move assignment in C++11 + GenericDocument& operator=(GenericDocument&& rhs) RAPIDJSON_NOEXCEPT + { + // The cast to ValueType is necessary here, because otherwise it would + // attempt to call GenericValue's templated assignment operator. + ValueType::operator=(std::forward(rhs)); + + // Calling the destructor here would prematurely call stack_'s destructor + Destroy(); + + allocator_ = rhs.allocator_; + ownAllocator_ = rhs.ownAllocator_; + stack_ = std::move(rhs.stack_); + parseResult_ = rhs.parseResult_; + + rhs.allocator_ = 0; + rhs.ownAllocator_ = 0; + rhs.parseResult_ = ParseResult(); + + return *this; + } +#endif + + //! Exchange the contents of this document with those of another. + /*! + \param rhs Another document. + \note Constant complexity. + \see GenericValue::Swap + */ + GenericDocument& Swap(GenericDocument& rhs) RAPIDJSON_NOEXCEPT { + ValueType::Swap(rhs); + stack_.Swap(rhs.stack_); + internal::Swap(allocator_, rhs.allocator_); + internal::Swap(ownAllocator_, rhs.ownAllocator_); + internal::Swap(parseResult_, rhs.parseResult_); + return *this; + } + + // Allow Swap with ValueType. + // Refer to Effective C++ 3rd Edition/Item 33: Avoid hiding inherited names. + using ValueType::Swap; + + //! free-standing swap function helper + /*! + Helper function to enable support for common swap implementation pattern based on \c std::swap: + \code + void swap(MyClass& a, MyClass& b) { + using std::swap; + swap(a.doc, b.doc); + // ... + } + \endcode + \see Swap() + */ + friend inline void swap(GenericDocument& a, GenericDocument& b) RAPIDJSON_NOEXCEPT { a.Swap(b); } + + //! Populate this document by a generator which produces SAX events. + /*! \tparam Generator A functor with bool f(Handler) prototype. + \param g Generator functor which sends SAX events to the parameter. + \return The document itself for fluent API. + */ + template + GenericDocument& Populate(Generator& g) { + ClearStackOnExit scope(*this); + if (g(*this)) { + RAPIDJSON_ASSERT(stack_.GetSize() == sizeof(ValueType)); // Got one and only one root object + ValueType::operator=(*stack_.template Pop(1));// Move value from stack to document + } + return *this; + } + + //!@name Parse from stream + //!@{ + + //! Parse JSON text from an input stream (with Encoding conversion) + /*! \tparam parseFlags Combination of \ref ParseFlag. + \tparam SourceEncoding Encoding of input stream + \tparam InputStream Type of input stream, implementing Stream concept + \param is Input stream to be parsed. + \return The document itself for fluent API. + */ + template + GenericDocument& ParseStream(InputStream& is) { + GenericReader reader( + stack_.HasAllocator() ? &stack_.GetAllocator() : 0); + ClearStackOnExit scope(*this); + parseResult_ = reader.template Parse(is, *this); + if (parseResult_) { + RAPIDJSON_ASSERT(stack_.GetSize() == sizeof(ValueType)); // Got one and only one root object + ValueType::operator=(*stack_.template Pop(1));// Move value from stack to document + } + return *this; + } + + //! Parse JSON text from an input stream + /*! \tparam parseFlags Combination of \ref ParseFlag. + \tparam InputStream Type of input stream, implementing Stream concept + \param is Input stream to be parsed. + \return The document itself for fluent API. + */ + template + GenericDocument& ParseStream(InputStream& is) { + return ParseStream(is); + } + + //! Parse JSON text from an input stream (with \ref kParseDefaultFlags) + /*! \tparam InputStream Type of input stream, implementing Stream concept + \param is Input stream to be parsed. + \return The document itself for fluent API. + */ + template + GenericDocument& ParseStream(InputStream& is) { + return ParseStream(is); + } + //!@} + + //!@name Parse in-place from mutable string + //!@{ + + //! Parse JSON text from a mutable string + /*! \tparam parseFlags Combination of \ref ParseFlag. + \param str Mutable zero-terminated string to be parsed. + \return The document itself for fluent API. + */ + template + GenericDocument& ParseInsitu(Ch* str) { + GenericInsituStringStream s(str); + return ParseStream(s); + } + + //! Parse JSON text from a mutable string (with \ref kParseDefaultFlags) + /*! \param str Mutable zero-terminated string to be parsed. + \return The document itself for fluent API. + */ + GenericDocument& ParseInsitu(Ch* str) { + return ParseInsitu(str); + } + //!@} + + //!@name Parse from read-only string + //!@{ + + //! Parse JSON text from a read-only string (with Encoding conversion) + /*! \tparam parseFlags Combination of \ref ParseFlag (must not contain \ref kParseInsituFlag). + \tparam SourceEncoding Transcoding from input Encoding + \param str Read-only zero-terminated string to be parsed. + */ + template + GenericDocument& Parse(const typename SourceEncoding::Ch* str) { + RAPIDJSON_ASSERT(!(parseFlags & kParseInsituFlag)); + GenericStringStream s(str); + return ParseStream(s); + } + + //! Parse JSON text from a read-only string + /*! \tparam parseFlags Combination of \ref ParseFlag (must not contain \ref kParseInsituFlag). + \param str Read-only zero-terminated string to be parsed. + */ + template + GenericDocument& Parse(const Ch* str) { + return Parse(str); + } + + //! Parse JSON text from a read-only string (with \ref kParseDefaultFlags) + /*! \param str Read-only zero-terminated string to be parsed. + */ + GenericDocument& Parse(const Ch* str) { + return Parse(str); + } + + template + GenericDocument& Parse(const typename SourceEncoding::Ch* str, size_t length) { + RAPIDJSON_ASSERT(!(parseFlags & kParseInsituFlag)); + MemoryStream ms(reinterpret_cast(str), length * sizeof(typename SourceEncoding::Ch)); + EncodedInputStream is(ms); + ParseStream(is); + return *this; + } + + template + GenericDocument& Parse(const Ch* str, size_t length) { + return Parse(str, length); + } + + GenericDocument& Parse(const Ch* str, size_t length) { + return Parse(str, length); + } + +#if RAPIDJSON_HAS_STDSTRING + template + GenericDocument& Parse(const std::basic_string& str) { + // c_str() is constant complexity according to standard. Should be faster than Parse(const char*, size_t) + return Parse(str.c_str()); + } + + template + GenericDocument& Parse(const std::basic_string& str) { + return Parse(str.c_str()); + } + + GenericDocument& Parse(const std::basic_string& str) { + return Parse(str); + } +#endif // RAPIDJSON_HAS_STDSTRING + + //!@} + + //!@name Handling parse errors + //!@{ + + //! Whether a parse error has occurred in the last parsing. + bool HasParseError() const { return parseResult_.IsError(); } + + //! Get the \ref ParseErrorCode of last parsing. + ParseErrorCode GetParseError() const { return parseResult_.Code(); } + + //! Get the position of last parsing error in input, 0 otherwise. + size_t GetErrorOffset() const { return parseResult_.Offset(); } + + //! Implicit conversion to get the last parse result +#ifndef __clang // -Wdocumentation + /*! \return \ref ParseResult of the last parse operation + + \code + Document doc; + ParseResult ok = doc.Parse(json); + if (!ok) + printf( "JSON parse error: %s (%u)\n", GetParseError_En(ok.Code()), ok.Offset()); + \endcode + */ +#endif + operator ParseResult() const { return parseResult_; } + //!@} + + //! Get the allocator of this document. + Allocator& GetAllocator() { + RAPIDJSON_ASSERT(allocator_); + return *allocator_; + } + + //! Get the capacity of stack in bytes. + size_t GetStackCapacity() const { return stack_.GetCapacity(); } + +private: + // clear stack on any exit from ParseStream, e.g. due to exception + struct ClearStackOnExit { + explicit ClearStackOnExit(GenericDocument& d) : d_(d) {} + ~ClearStackOnExit() { d_.ClearStack(); } + private: + ClearStackOnExit(const ClearStackOnExit&); + ClearStackOnExit& operator=(const ClearStackOnExit&); + GenericDocument& d_; + }; + + // callers of the following private Handler functions + // template friend class GenericReader; // for parsing + template friend class GenericValue; // for deep copying + +public: + // Implementation of Handler + bool Null() { new (stack_.template Push()) ValueType(); return true; } + bool Bool(bool b) { new (stack_.template Push()) ValueType(b); return true; } + bool Int(int i) { new (stack_.template Push()) ValueType(i); return true; } + bool Uint(unsigned i) { new (stack_.template Push()) ValueType(i); return true; } + bool Int64(int64_t i) { new (stack_.template Push()) ValueType(i); return true; } + bool Uint64(uint64_t i) { new (stack_.template Push()) ValueType(i); return true; } + bool Double(double d) { new (stack_.template Push()) ValueType(d); return true; } + + bool RawNumber(const Ch* str, SizeType length, bool copy) { + if (copy) + new (stack_.template Push()) ValueType(str, length, GetAllocator()); + else + new (stack_.template Push()) ValueType(str, length); + return true; + } + + bool String(const Ch* str, SizeType length, bool copy) { + if (copy) + new (stack_.template Push()) ValueType(str, length, GetAllocator()); + else + new (stack_.template Push()) ValueType(str, length); + return true; + } + + bool StartObject() { new (stack_.template Push()) ValueType(kObjectType); return true; } + + bool Key(const Ch* str, SizeType length, bool copy) { return String(str, length, copy); } + + bool EndObject(SizeType memberCount) { + typename ValueType::Member* members = stack_.template Pop(memberCount); + stack_.template Top()->SetObjectRaw(members, memberCount, GetAllocator()); + return true; + } + + bool StartArray() { new (stack_.template Push()) ValueType(kArrayType); return true; } + + bool EndArray(SizeType elementCount) { + ValueType* elements = stack_.template Pop(elementCount); + stack_.template Top()->SetArrayRaw(elements, elementCount, GetAllocator()); + return true; + } + +private: + //! Prohibit copying + GenericDocument(const GenericDocument&); + //! Prohibit assignment + GenericDocument& operator=(const GenericDocument&); + + void ClearStack() { + if (Allocator::kNeedFree) + while (stack_.GetSize() > 0) // Here assumes all elements in stack array are GenericValue (Member is actually 2 GenericValue objects) + (stack_.template Pop(1))->~ValueType(); + else + stack_.Clear(); + stack_.ShrinkToFit(); + } + + void Destroy() { + RAPIDJSON_DELETE(ownAllocator_); + } + + static const size_t kDefaultStackCapacity = 1024; + Allocator* allocator_; + Allocator* ownAllocator_; + internal::Stack stack_; + ParseResult parseResult_; +}; + +//! GenericDocument with UTF8 encoding +typedef GenericDocument > Document; + + +//! Helper class for accessing Value of array type. +/*! + Instance of this helper class is obtained by \c GenericValue::GetArray(). + In addition to all APIs for array type, it provides range-based for loop if \c RAPIDJSON_HAS_CXX11_RANGE_FOR=1. +*/ +template +class GenericArray { +public: + typedef GenericArray ConstArray; + typedef GenericArray Array; + typedef ValueT PlainType; + typedef typename internal::MaybeAddConst::Type ValueType; + typedef ValueType* ValueIterator; // This may be const or non-const iterator + typedef const ValueT* ConstValueIterator; + typedef typename ValueType::AllocatorType AllocatorType; + typedef typename ValueType::StringRefType StringRefType; + + template + friend class GenericValue; + + GenericArray(const GenericArray& rhs) : value_(rhs.value_) {} + GenericArray& operator=(const GenericArray& rhs) { value_ = rhs.value_; return *this; } + ~GenericArray() {} + + operator ValueType&() const { return value_; } + SizeType Size() const { return value_.Size(); } + SizeType Capacity() const { return value_.Capacity(); } + bool Empty() const { return value_.Empty(); } + void Clear() const { value_.Clear(); } + ValueType& operator[](SizeType index) const { return value_[index]; } + ValueIterator Begin() const { return value_.Begin(); } + ValueIterator End() const { return value_.End(); } + GenericArray Reserve(SizeType newCapacity, AllocatorType &allocator) const { value_.Reserve(newCapacity, allocator); return *this; } + GenericArray PushBack(ValueType& value, AllocatorType& allocator) const { value_.PushBack(value, allocator); return *this; } +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericArray PushBack(ValueType&& value, AllocatorType& allocator) const { value_.PushBack(value, allocator); return *this; } +#endif // RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericArray PushBack(StringRefType value, AllocatorType& allocator) const { value_.PushBack(value, allocator); return *this; } + template RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (const GenericArray&)) PushBack(T value, AllocatorType& allocator) const { value_.PushBack(value, allocator); return *this; } + GenericArray PopBack() const { value_.PopBack(); return *this; } + ValueIterator Erase(ConstValueIterator pos) const { return value_.Erase(pos); } + ValueIterator Erase(ConstValueIterator first, ConstValueIterator last) const { return value_.Erase(first, last); } + +#if RAPIDJSON_HAS_CXX11_RANGE_FOR + ValueIterator begin() const { return value_.Begin(); } + ValueIterator end() const { return value_.End(); } +#endif + +private: + GenericArray(); + GenericArray(ValueType& value) : value_(value) {} + ValueType& value_; +}; + +//! Helper class for accessing Value of object type. +/*! + Instance of this helper class is obtained by \c GenericValue::GetObject(). + In addition to all APIs for array type, it provides range-based for loop if \c RAPIDJSON_HAS_CXX11_RANGE_FOR=1. +*/ +template +class GenericObject { +public: + typedef GenericObject ConstObject; + typedef GenericObject Object; + typedef ValueT PlainType; + typedef typename internal::MaybeAddConst::Type ValueType; + typedef GenericMemberIterator MemberIterator; // This may be const or non-const iterator + typedef GenericMemberIterator ConstMemberIterator; + typedef typename ValueType::AllocatorType AllocatorType; + typedef typename ValueType::StringRefType StringRefType; + typedef typename ValueType::EncodingType EncodingType; + typedef typename ValueType::Ch Ch; + + template + friend class GenericValue; + + GenericObject(const GenericObject& rhs) : value_(rhs.value_) {} + GenericObject& operator=(const GenericObject& rhs) { value_ = rhs.value_; return *this; } + ~GenericObject() {} + + operator ValueType&() const { return value_; } + SizeType MemberCount() const { return value_.MemberCount(); } + SizeType MemberCapacity() const { return value_.MemberCapacity(); } + bool ObjectEmpty() const { return value_.ObjectEmpty(); } + template ValueType& operator[](T* name) const { return value_[name]; } + template ValueType& operator[](const GenericValue& name) const { return value_[name]; } +#if RAPIDJSON_HAS_STDSTRING + ValueType& operator[](const std::basic_string& name) const { return value_[name]; } +#endif + MemberIterator MemberBegin() const { return value_.MemberBegin(); } + MemberIterator MemberEnd() const { return value_.MemberEnd(); } + GenericObject MemberReserve(SizeType newCapacity, AllocatorType &allocator) const { value_.MemberReserve(newCapacity, allocator); return *this; } + bool HasMember(const Ch* name) const { return value_.HasMember(name); } +#if RAPIDJSON_HAS_STDSTRING + bool HasMember(const std::basic_string& name) const { return value_.HasMember(name); } +#endif + template bool HasMember(const GenericValue& name) const { return value_.HasMember(name); } + MemberIterator FindMember(const Ch* name) const { return value_.FindMember(name); } + template MemberIterator FindMember(const GenericValue& name) const { return value_.FindMember(name); } +#if RAPIDJSON_HAS_STDSTRING + MemberIterator FindMember(const std::basic_string& name) const { return value_.FindMember(name); } +#endif + GenericObject AddMember(ValueType& name, ValueType& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + GenericObject AddMember(ValueType& name, StringRefType value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } +#if RAPIDJSON_HAS_STDSTRING + GenericObject AddMember(ValueType& name, std::basic_string& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } +#endif + template RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (ValueType&)) AddMember(ValueType& name, T value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericObject AddMember(ValueType&& name, ValueType&& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + GenericObject AddMember(ValueType&& name, ValueType& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + GenericObject AddMember(ValueType& name, ValueType&& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + GenericObject AddMember(StringRefType name, ValueType&& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } +#endif // RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericObject AddMember(StringRefType name, ValueType& value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + GenericObject AddMember(StringRefType name, StringRefType value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + template RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (GenericObject)) AddMember(StringRefType name, T value, AllocatorType& allocator) const { value_.AddMember(name, value, allocator); return *this; } + void RemoveAllMembers() { value_.RemoveAllMembers(); } + bool RemoveMember(const Ch* name) const { return value_.RemoveMember(name); } +#if RAPIDJSON_HAS_STDSTRING + bool RemoveMember(const std::basic_string& name) const { return value_.RemoveMember(name); } +#endif + template bool RemoveMember(const GenericValue& name) const { return value_.RemoveMember(name); } + MemberIterator RemoveMember(MemberIterator m) const { return value_.RemoveMember(m); } + MemberIterator EraseMember(ConstMemberIterator pos) const { return value_.EraseMember(pos); } + MemberIterator EraseMember(ConstMemberIterator first, ConstMemberIterator last) const { return value_.EraseMember(first, last); } + bool EraseMember(const Ch* name) const { return value_.EraseMember(name); } +#if RAPIDJSON_HAS_STDSTRING + bool EraseMember(const std::basic_string& name) const { return EraseMember(ValueType(StringRef(name))); } +#endif + template bool EraseMember(const GenericValue& name) const { return value_.EraseMember(name); } + +#if RAPIDJSON_HAS_CXX11_RANGE_FOR + MemberIterator begin() const { return value_.MemberBegin(); } + MemberIterator end() const { return value_.MemberEnd(); } +#endif + +private: + GenericObject(); + GenericObject(ValueType& value) : value_(value) {} + ValueType& value_; +}; + +RAPIDJSON_NAMESPACE_END +RAPIDJSON_DIAG_POP + +#ifdef RAPIDJSON_WINDOWS_GETOBJECT_WORKAROUND_APPLIED +#pragma pop_macro("GetObject") +#undef RAPIDJSON_WINDOWS_GETOBJECT_WORKAROUND_APPLIED +#endif + +#endif // RAPIDJSON_DOCUMENT_H_ diff --git a/include/rapidjson/encodedstream.h b/include/rapidjson/encodedstream.h new file mode 100644 index 0000000000..cf046b8923 --- /dev/null +++ b/include/rapidjson/encodedstream.h @@ -0,0 +1,299 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ENCODEDSTREAM_H_ +#define RAPIDJSON_ENCODEDSTREAM_H_ + +#include "stream.h" +#include "memorystream.h" + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Input byte stream wrapper with a statically bound encoding. +/*! + \tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE. + \tparam InputByteStream Type of input byte stream. For example, FileReadStream. +*/ +template +class EncodedInputStream { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); +public: + typedef typename Encoding::Ch Ch; + + EncodedInputStream(InputByteStream& is) : is_(is) { + current_ = Encoding::TakeBOM(is_); + } + + Ch Peek() const { return current_; } + Ch Take() { Ch c = current_; current_ = Encoding::Take(is_); return c; } + size_t Tell() const { return is_.Tell(); } + + // Not implemented + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + EncodedInputStream(const EncodedInputStream&); + EncodedInputStream& operator=(const EncodedInputStream&); + + InputByteStream& is_; + Ch current_; +}; + +//! Specialized for UTF8 MemoryStream. +template <> +class EncodedInputStream, MemoryStream> { +public: + typedef UTF8<>::Ch Ch; + + EncodedInputStream(MemoryStream& is) : is_(is) { + if (static_cast(is_.Peek()) == 0xEFu) is_.Take(); + if (static_cast(is_.Peek()) == 0xBBu) is_.Take(); + if (static_cast(is_.Peek()) == 0xBFu) is_.Take(); + } + Ch Peek() const { return is_.Peek(); } + Ch Take() { return is_.Take(); } + size_t Tell() const { return is_.Tell(); } + + // Not implemented + void Put(Ch) {} + void Flush() {} + Ch* PutBegin() { return 0; } + size_t PutEnd(Ch*) { return 0; } + + MemoryStream& is_; + +private: + EncodedInputStream(const EncodedInputStream&); + EncodedInputStream& operator=(const EncodedInputStream&); +}; + +//! Output byte stream wrapper with statically bound encoding. +/*! + \tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE. + \tparam OutputByteStream Type of input byte stream. For example, FileWriteStream. +*/ +template +class EncodedOutputStream { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); +public: + typedef typename Encoding::Ch Ch; + + EncodedOutputStream(OutputByteStream& os, bool putBOM = true) : os_(os) { + if (putBOM) + Encoding::PutBOM(os_); + } + + void Put(Ch c) { Encoding::Put(os_, c); } + void Flush() { os_.Flush(); } + + // Not implemented + Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;} + Ch Take() { RAPIDJSON_ASSERT(false); return 0;} + size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + EncodedOutputStream(const EncodedOutputStream&); + EncodedOutputStream& operator=(const EncodedOutputStream&); + + OutputByteStream& os_; +}; + +#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8::x, UTF16LE::x, UTF16BE::x, UTF32LE::x, UTF32BE::x + +//! Input stream wrapper with dynamically bound encoding and automatic encoding detection. +/*! + \tparam CharType Type of character for reading. + \tparam InputByteStream type of input byte stream to be wrapped. +*/ +template +class AutoUTFInputStream { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); +public: + typedef CharType Ch; + + //! Constructor. + /*! + \param is input stream to be wrapped. + \param type UTF encoding type if it is not detected from the stream. + */ + AutoUTFInputStream(InputByteStream& is, UTFType type = kUTF8) : is_(&is), type_(type), hasBOM_(false) { + RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE); + DetectType(); + static const TakeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Take) }; + takeFunc_ = f[type_]; + current_ = takeFunc_(*is_); + } + + UTFType GetType() const { return type_; } + bool HasBOM() const { return hasBOM_; } + + Ch Peek() const { return current_; } + Ch Take() { Ch c = current_; current_ = takeFunc_(*is_); return c; } + size_t Tell() const { return is_->Tell(); } + + // Not implemented + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + AutoUTFInputStream(const AutoUTFInputStream&); + AutoUTFInputStream& operator=(const AutoUTFInputStream&); + + // Detect encoding type with BOM or RFC 4627 + void DetectType() { + // BOM (Byte Order Mark): + // 00 00 FE FF UTF-32BE + // FF FE 00 00 UTF-32LE + // FE FF UTF-16BE + // FF FE UTF-16LE + // EF BB BF UTF-8 + + const unsigned char* c = reinterpret_cast(is_->Peek4()); + if (!c) + return; + + unsigned bom = static_cast(c[0] | (c[1] << 8) | (c[2] << 16) | (c[3] << 24)); + hasBOM_ = false; + if (bom == 0xFFFE0000) { type_ = kUTF32BE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); } + else if (bom == 0x0000FEFF) { type_ = kUTF32LE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); } + else if ((bom & 0xFFFF) == 0xFFFE) { type_ = kUTF16BE; hasBOM_ = true; is_->Take(); is_->Take(); } + else if ((bom & 0xFFFF) == 0xFEFF) { type_ = kUTF16LE; hasBOM_ = true; is_->Take(); is_->Take(); } + else if ((bom & 0xFFFFFF) == 0xBFBBEF) { type_ = kUTF8; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); } + + // RFC 4627: Section 3 + // "Since the first two characters of a JSON text will always be ASCII + // characters [RFC0020], it is possible to determine whether an octet + // stream is UTF-8, UTF-16 (BE or LE), or UTF-32 (BE or LE) by looking + // at the pattern of nulls in the first four octets." + // 00 00 00 xx UTF-32BE + // 00 xx 00 xx UTF-16BE + // xx 00 00 00 UTF-32LE + // xx 00 xx 00 UTF-16LE + // xx xx xx xx UTF-8 + + if (!hasBOM_) { + int pattern = (c[0] ? 1 : 0) | (c[1] ? 2 : 0) | (c[2] ? 4 : 0) | (c[3] ? 8 : 0); + switch (pattern) { + case 0x08: type_ = kUTF32BE; break; + case 0x0A: type_ = kUTF16BE; break; + case 0x01: type_ = kUTF32LE; break; + case 0x05: type_ = kUTF16LE; break; + case 0x0F: type_ = kUTF8; break; + default: break; // Use type defined by user. + } + } + + // Runtime check whether the size of character type is sufficient. It only perform checks with assertion. + if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2); + if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4); + } + + typedef Ch (*TakeFunc)(InputByteStream& is); + InputByteStream* is_; + UTFType type_; + Ch current_; + TakeFunc takeFunc_; + bool hasBOM_; +}; + +//! Output stream wrapper with dynamically bound encoding and automatic encoding detection. +/*! + \tparam CharType Type of character for writing. + \tparam OutputByteStream type of output byte stream to be wrapped. +*/ +template +class AutoUTFOutputStream { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); +public: + typedef CharType Ch; + + //! Constructor. + /*! + \param os output stream to be wrapped. + \param type UTF encoding type. + \param putBOM Whether to write BOM at the beginning of the stream. + */ + AutoUTFOutputStream(OutputByteStream& os, UTFType type, bool putBOM) : os_(&os), type_(type) { + RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE); + + // Runtime check whether the size of character type is sufficient. It only perform checks with assertion. + if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2); + if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4); + + static const PutFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Put) }; + putFunc_ = f[type_]; + + if (putBOM) + PutBOM(); + } + + UTFType GetType() const { return type_; } + + void Put(Ch c) { putFunc_(*os_, c); } + void Flush() { os_->Flush(); } + + // Not implemented + Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;} + Ch Take() { RAPIDJSON_ASSERT(false); return 0;} + size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + AutoUTFOutputStream(const AutoUTFOutputStream&); + AutoUTFOutputStream& operator=(const AutoUTFOutputStream&); + + void PutBOM() { + typedef void (*PutBOMFunc)(OutputByteStream&); + static const PutBOMFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(PutBOM) }; + f[type_](*os_); + } + + typedef void (*PutFunc)(OutputByteStream&, Ch); + + OutputByteStream* os_; + UTFType type_; + PutFunc putFunc_; +}; + +#undef RAPIDJSON_ENCODINGS_FUNC + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_FILESTREAM_H_ diff --git a/include/rapidjson/encodings.h b/include/rapidjson/encodings.h new file mode 100644 index 0000000000..c453c0da31 --- /dev/null +++ b/include/rapidjson/encodings.h @@ -0,0 +1,716 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ENCODINGS_H_ +#define RAPIDJSON_ENCODINGS_H_ + +#include "rapidjson.h" + +#if defined(_MSC_VER) && !defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4244) // conversion from 'type1' to 'type2', possible loss of data +RAPIDJSON_DIAG_OFF(4702) // unreachable code +#elif defined(__GNUC__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +RAPIDJSON_DIAG_OFF(overflow) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// Encoding + +/*! \class rapidjson::Encoding + \brief Concept for encoding of Unicode characters. + +\code +concept Encoding { + typename Ch; //! Type of character. A "character" is actually a code unit in unicode's definition. + + enum { supportUnicode = 1 }; // or 0 if not supporting unicode + + //! \brief Encode a Unicode codepoint to an output stream. + //! \param os Output stream. + //! \param codepoint An unicode codepoint, ranging from 0x0 to 0x10FFFF inclusively. + template + static void Encode(OutputStream& os, unsigned codepoint); + + //! \brief Decode a Unicode codepoint from an input stream. + //! \param is Input stream. + //! \param codepoint Output of the unicode codepoint. + //! \return true if a valid codepoint can be decoded from the stream. + template + static bool Decode(InputStream& is, unsigned* codepoint); + + //! \brief Validate one Unicode codepoint from an encoded stream. + //! \param is Input stream to obtain codepoint. + //! \param os Output for copying one codepoint. + //! \return true if it is valid. + //! \note This function just validating and copying the codepoint without actually decode it. + template + static bool Validate(InputStream& is, OutputStream& os); + + // The following functions are deal with byte streams. + + //! Take a character from input byte stream, skip BOM if exist. + template + static CharType TakeBOM(InputByteStream& is); + + //! Take a character from input byte stream. + template + static Ch Take(InputByteStream& is); + + //! Put BOM to output byte stream. + template + static void PutBOM(OutputByteStream& os); + + //! Put a character to output byte stream. + template + static void Put(OutputByteStream& os, Ch c); +}; +\endcode +*/ + +/////////////////////////////////////////////////////////////////////////////// +// UTF8 + +//! UTF-8 encoding. +/*! http://en.wikipedia.org/wiki/UTF-8 + http://tools.ietf.org/html/rfc3629 + \tparam CharType Code unit for storing 8-bit UTF-8 data. Default is char. + \note implements Encoding concept +*/ +template +struct UTF8 { + typedef CharType Ch; + + enum { supportUnicode = 1 }; + + template + static void Encode(OutputStream& os, unsigned codepoint) { + if (codepoint <= 0x7F) + os.Put(static_cast(codepoint & 0xFF)); + else if (codepoint <= 0x7FF) { + os.Put(static_cast(0xC0 | ((codepoint >> 6) & 0xFF))); + os.Put(static_cast(0x80 | ((codepoint & 0x3F)))); + } + else if (codepoint <= 0xFFFF) { + os.Put(static_cast(0xE0 | ((codepoint >> 12) & 0xFF))); + os.Put(static_cast(0x80 | ((codepoint >> 6) & 0x3F))); + os.Put(static_cast(0x80 | (codepoint & 0x3F))); + } + else { + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + os.Put(static_cast(0xF0 | ((codepoint >> 18) & 0xFF))); + os.Put(static_cast(0x80 | ((codepoint >> 12) & 0x3F))); + os.Put(static_cast(0x80 | ((codepoint >> 6) & 0x3F))); + os.Put(static_cast(0x80 | (codepoint & 0x3F))); + } + } + + template + static void EncodeUnsafe(OutputStream& os, unsigned codepoint) { + if (codepoint <= 0x7F) + PutUnsafe(os, static_cast(codepoint & 0xFF)); + else if (codepoint <= 0x7FF) { + PutUnsafe(os, static_cast(0xC0 | ((codepoint >> 6) & 0xFF))); + PutUnsafe(os, static_cast(0x80 | ((codepoint & 0x3F)))); + } + else if (codepoint <= 0xFFFF) { + PutUnsafe(os, static_cast(0xE0 | ((codepoint >> 12) & 0xFF))); + PutUnsafe(os, static_cast(0x80 | ((codepoint >> 6) & 0x3F))); + PutUnsafe(os, static_cast(0x80 | (codepoint & 0x3F))); + } + else { + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + PutUnsafe(os, static_cast(0xF0 | ((codepoint >> 18) & 0xFF))); + PutUnsafe(os, static_cast(0x80 | ((codepoint >> 12) & 0x3F))); + PutUnsafe(os, static_cast(0x80 | ((codepoint >> 6) & 0x3F))); + PutUnsafe(os, static_cast(0x80 | (codepoint & 0x3F))); + } + } + + template + static bool Decode(InputStream& is, unsigned* codepoint) { +#define RAPIDJSON_COPY() c = is.Take(); *codepoint = (*codepoint << 6) | (static_cast(c) & 0x3Fu) +#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast(c)) & mask) != 0) +#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70) + typename InputStream::Ch c = is.Take(); + if (!(c & 0x80)) { + *codepoint = static_cast(c); + return true; + } + + unsigned char type = GetRange(static_cast(c)); + if (type >= 32) { + *codepoint = 0; + } else { + *codepoint = (0xFFu >> type) & static_cast(c); + } + bool result = true; + switch (type) { + case 2: RAPIDJSON_TAIL(); return result; + case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result; + case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result; + case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + default: return false; + } +#undef RAPIDJSON_COPY +#undef RAPIDJSON_TRANS +#undef RAPIDJSON_TAIL + } + + template + static bool Validate(InputStream& is, OutputStream& os) { +#define RAPIDJSON_COPY() if (c != '\0') os.Put(c = is.Take()) +#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast(c)) & mask) != 0) +#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70) + Ch c = static_cast(-1); + RAPIDJSON_COPY(); + if (!(c & 0x80)) + return true; + + bool result = true; + switch (GetRange(static_cast(c))) { + case 2: RAPIDJSON_TAIL(); return result; + case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result; + case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result; + case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result; + default: return false; + } +#undef RAPIDJSON_COPY +#undef RAPIDJSON_TRANS +#undef RAPIDJSON_TAIL + } + + static unsigned char GetRange(unsigned char c) { + // Referring to DFA of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + // With new mapping 1 -> 0x10, 7 -> 0x20, 9 -> 0x40, such that AND operation can test multiple types. + static const unsigned char type[] = { + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, + 0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10, + 0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40, + 0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20, + 0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20, + 8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, + 10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8, + }; + return type[c]; + } + + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + typename InputByteStream::Ch c = Take(is); + if (static_cast(c) != 0xEFu) return c; + c = is.Take(); + if (static_cast(c) != 0xBBu) return c; + c = is.Take(); + if (static_cast(c) != 0xBFu) return c; + c = is.Take(); + return c; + } + + template + static Ch Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + return static_cast(is.Take()); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(0xEFu)); + os.Put(static_cast(0xBBu)); + os.Put(static_cast(0xBFu)); + } + + template + static void Put(OutputByteStream& os, Ch c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(c)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// UTF16 + +//! UTF-16 encoding. +/*! http://en.wikipedia.org/wiki/UTF-16 + http://tools.ietf.org/html/rfc2781 + \tparam CharType Type for storing 16-bit UTF-16 data. Default is wchar_t. C++11 may use char16_t instead. + \note implements Encoding concept + + \note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness. + For streaming, use UTF16LE and UTF16BE, which handle endianness. +*/ +template +struct UTF16 { + typedef CharType Ch; + RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 2); + + enum { supportUnicode = 1 }; + + template + static void Encode(OutputStream& os, unsigned codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2); + if (codepoint <= 0xFFFF) { + RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair + os.Put(static_cast(codepoint)); + } + else { + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + unsigned v = codepoint - 0x10000; + os.Put(static_cast((v >> 10) | 0xD800)); + os.Put(static_cast((v & 0x3FF) | 0xDC00)); + } + } + + + template + static void EncodeUnsafe(OutputStream& os, unsigned codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2); + if (codepoint <= 0xFFFF) { + RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair + PutUnsafe(os, static_cast(codepoint)); + } + else { + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + unsigned v = codepoint - 0x10000; + PutUnsafe(os, static_cast((v >> 10) | 0xD800)); + PutUnsafe(os, static_cast((v & 0x3FF) | 0xDC00)); + } + } + + template + static bool Decode(InputStream& is, unsigned* codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2); + typename InputStream::Ch c = is.Take(); + if (c < 0xD800 || c > 0xDFFF) { + *codepoint = static_cast(c); + return true; + } + else if (c <= 0xDBFF) { + *codepoint = (static_cast(c) & 0x3FF) << 10; + c = is.Take(); + *codepoint |= (static_cast(c) & 0x3FF); + *codepoint += 0x10000; + return c >= 0xDC00 && c <= 0xDFFF; + } + return false; + } + + template + static bool Validate(InputStream& is, OutputStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2); + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2); + typename InputStream::Ch c; + os.Put(static_cast(c = is.Take())); + if (c < 0xD800 || c > 0xDFFF) + return true; + else if (c <= 0xDBFF) { + os.Put(c = is.Take()); + return c >= 0xDC00 && c <= 0xDFFF; + } + return false; + } +}; + +//! UTF-16 little endian encoding. +template +struct UTF16LE : UTF16 { + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + CharType c = Take(is); + return static_cast(c) == 0xFEFFu ? Take(is) : c; + } + + template + static CharType Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + unsigned c = static_cast(is.Take()); + c |= static_cast(static_cast(is.Take())) << 8; + return static_cast(c); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(0xFFu)); + os.Put(static_cast(0xFEu)); + } + + template + static void Put(OutputByteStream& os, CharType c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(static_cast(c) & 0xFFu)); + os.Put(static_cast((static_cast(c) >> 8) & 0xFFu)); + } +}; + +//! UTF-16 big endian encoding. +template +struct UTF16BE : UTF16 { + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + CharType c = Take(is); + return static_cast(c) == 0xFEFFu ? Take(is) : c; + } + + template + static CharType Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + unsigned c = static_cast(static_cast(is.Take())) << 8; + c |= static_cast(static_cast(is.Take())); + return static_cast(c); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(0xFEu)); + os.Put(static_cast(0xFFu)); + } + + template + static void Put(OutputByteStream& os, CharType c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast((static_cast(c) >> 8) & 0xFFu)); + os.Put(static_cast(static_cast(c) & 0xFFu)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// UTF32 + +//! UTF-32 encoding. +/*! http://en.wikipedia.org/wiki/UTF-32 + \tparam CharType Type for storing 32-bit UTF-32 data. Default is unsigned. C++11 may use char32_t instead. + \note implements Encoding concept + + \note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness. + For streaming, use UTF32LE and UTF32BE, which handle endianness. +*/ +template +struct UTF32 { + typedef CharType Ch; + RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 4); + + enum { supportUnicode = 1 }; + + template + static void Encode(OutputStream& os, unsigned codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4); + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + os.Put(codepoint); + } + + template + static void EncodeUnsafe(OutputStream& os, unsigned codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4); + RAPIDJSON_ASSERT(codepoint <= 0x10FFFF); + PutUnsafe(os, codepoint); + } + + template + static bool Decode(InputStream& is, unsigned* codepoint) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4); + Ch c = is.Take(); + *codepoint = c; + return c <= 0x10FFFF; + } + + template + static bool Validate(InputStream& is, OutputStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4); + Ch c; + os.Put(c = is.Take()); + return c <= 0x10FFFF; + } +}; + +//! UTF-32 little endian enocoding. +template +struct UTF32LE : UTF32 { + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + CharType c = Take(is); + return static_cast(c) == 0x0000FEFFu ? Take(is) : c; + } + + template + static CharType Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + unsigned c = static_cast(is.Take()); + c |= static_cast(static_cast(is.Take())) << 8; + c |= static_cast(static_cast(is.Take())) << 16; + c |= static_cast(static_cast(is.Take())) << 24; + return static_cast(c); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(0xFFu)); + os.Put(static_cast(0xFEu)); + os.Put(static_cast(0x00u)); + os.Put(static_cast(0x00u)); + } + + template + static void Put(OutputByteStream& os, CharType c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(c & 0xFFu)); + os.Put(static_cast((c >> 8) & 0xFFu)); + os.Put(static_cast((c >> 16) & 0xFFu)); + os.Put(static_cast((c >> 24) & 0xFFu)); + } +}; + +//! UTF-32 big endian encoding. +template +struct UTF32BE : UTF32 { + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + CharType c = Take(is); + return static_cast(c) == 0x0000FEFFu ? Take(is) : c; + } + + template + static CharType Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + unsigned c = static_cast(static_cast(is.Take())) << 24; + c |= static_cast(static_cast(is.Take())) << 16; + c |= static_cast(static_cast(is.Take())) << 8; + c |= static_cast(static_cast(is.Take())); + return static_cast(c); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(0x00u)); + os.Put(static_cast(0x00u)); + os.Put(static_cast(0xFEu)); + os.Put(static_cast(0xFFu)); + } + + template + static void Put(OutputByteStream& os, CharType c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast((c >> 24) & 0xFFu)); + os.Put(static_cast((c >> 16) & 0xFFu)); + os.Put(static_cast((c >> 8) & 0xFFu)); + os.Put(static_cast(c & 0xFFu)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// ASCII + +//! ASCII encoding. +/*! http://en.wikipedia.org/wiki/ASCII + \tparam CharType Code unit for storing 7-bit ASCII data. Default is char. + \note implements Encoding concept +*/ +template +struct ASCII { + typedef CharType Ch; + + enum { supportUnicode = 0 }; + + template + static void Encode(OutputStream& os, unsigned codepoint) { + RAPIDJSON_ASSERT(codepoint <= 0x7F); + os.Put(static_cast(codepoint & 0xFF)); + } + + template + static void EncodeUnsafe(OutputStream& os, unsigned codepoint) { + RAPIDJSON_ASSERT(codepoint <= 0x7F); + PutUnsafe(os, static_cast(codepoint & 0xFF)); + } + + template + static bool Decode(InputStream& is, unsigned* codepoint) { + uint8_t c = static_cast(is.Take()); + *codepoint = c; + return c <= 0X7F; + } + + template + static bool Validate(InputStream& is, OutputStream& os) { + uint8_t c = static_cast(is.Take()); + os.Put(static_cast(c)); + return c <= 0x7F; + } + + template + static CharType TakeBOM(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + uint8_t c = static_cast(Take(is)); + return static_cast(c); + } + + template + static Ch Take(InputByteStream& is) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1); + return static_cast(is.Take()); + } + + template + static void PutBOM(OutputByteStream& os) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + (void)os; + } + + template + static void Put(OutputByteStream& os, Ch c) { + RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1); + os.Put(static_cast(c)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// AutoUTF + +//! Runtime-specified UTF encoding type of a stream. +enum UTFType { + kUTF8 = 0, //!< UTF-8. + kUTF16LE = 1, //!< UTF-16 little endian. + kUTF16BE = 2, //!< UTF-16 big endian. + kUTF32LE = 3, //!< UTF-32 little endian. + kUTF32BE = 4 //!< UTF-32 big endian. +}; + +//! Dynamically select encoding according to stream's runtime-specified UTF encoding type. +/*! \note This class can be used with AutoUTFInputtStream and AutoUTFOutputStream, which provides GetType(). +*/ +template +struct AutoUTF { + typedef CharType Ch; + + enum { supportUnicode = 1 }; + +#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8::x, UTF16LE::x, UTF16BE::x, UTF32LE::x, UTF32BE::x + + template + static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint) { + typedef void (*EncodeFunc)(OutputStream&, unsigned); + static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Encode) }; + (*f[os.GetType()])(os, codepoint); + } + + template + static RAPIDJSON_FORCEINLINE void EncodeUnsafe(OutputStream& os, unsigned codepoint) { + typedef void (*EncodeFunc)(OutputStream&, unsigned); + static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(EncodeUnsafe) }; + (*f[os.GetType()])(os, codepoint); + } + + template + static RAPIDJSON_FORCEINLINE bool Decode(InputStream& is, unsigned* codepoint) { + typedef bool (*DecodeFunc)(InputStream&, unsigned*); + static const DecodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Decode) }; + return (*f[is.GetType()])(is, codepoint); + } + + template + static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) { + typedef bool (*ValidateFunc)(InputStream&, OutputStream&); + static const ValidateFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Validate) }; + return (*f[is.GetType()])(is, os); + } + +#undef RAPIDJSON_ENCODINGS_FUNC +}; + +/////////////////////////////////////////////////////////////////////////////// +// Transcoder + +//! Encoding conversion. +template +struct Transcoder { + //! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to the output stream. + template + static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) { + unsigned codepoint; + if (!SourceEncoding::Decode(is, &codepoint)) + return false; + TargetEncoding::Encode(os, codepoint); + return true; + } + + template + static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) { + unsigned codepoint; + if (!SourceEncoding::Decode(is, &codepoint)) + return false; + TargetEncoding::EncodeUnsafe(os, codepoint); + return true; + } + + //! Validate one Unicode codepoint from an encoded stream. + template + static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) { + return Transcode(is, os); // Since source/target encoding is different, must transcode. + } +}; + +// Forward declaration. +template +inline void PutUnsafe(Stream& stream, typename Stream::Ch c); + +//! Specialization of Transcoder with same source and target encoding. +template +struct Transcoder { + template + static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) { + os.Put(is.Take()); // Just copy one code unit. This semantic is different from primary template class. + return true; + } + + template + static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) { + PutUnsafe(os, is.Take()); // Just copy one code unit. This semantic is different from primary template class. + return true; + } + + template + static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) { + return Encoding::Validate(is, os); // source/target encoding are the same + } +}; + +RAPIDJSON_NAMESPACE_END + +#if defined(__GNUC__) || (defined(_MSC_VER) && !defined(__clang__)) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_ENCODINGS_H_ diff --git a/include/rapidjson/error/en.h b/include/rapidjson/error/en.h new file mode 100644 index 0000000000..c87b04eb13 --- /dev/null +++ b/include/rapidjson/error/en.h @@ -0,0 +1,176 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ERROR_EN_H_ +#define RAPIDJSON_ERROR_EN_H_ + +#include "error.h" + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(switch-enum) +RAPIDJSON_DIAG_OFF(covered-switch-default) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Maps error code of parsing into error message. +/*! + \ingroup RAPIDJSON_ERRORS + \param parseErrorCode Error code obtained in parsing. + \return the error message. + \note User can make a copy of this function for localization. + Using switch-case is safer for future modification of error codes. +*/ +inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErrorCode) { + switch (parseErrorCode) { + case kParseErrorNone: return RAPIDJSON_ERROR_STRING("No error."); + + case kParseErrorDocumentEmpty: return RAPIDJSON_ERROR_STRING("The document is empty."); + case kParseErrorDocumentRootNotSingular: return RAPIDJSON_ERROR_STRING("The document root must not be followed by other values."); + + case kParseErrorValueInvalid: return RAPIDJSON_ERROR_STRING("Invalid value."); + + case kParseErrorObjectMissName: return RAPIDJSON_ERROR_STRING("Missing a name for object member."); + case kParseErrorObjectMissColon: return RAPIDJSON_ERROR_STRING("Missing a colon after a name of object member."); + case kParseErrorObjectMissCommaOrCurlyBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or '}' after an object member."); + + case kParseErrorArrayMissCommaOrSquareBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or ']' after an array element."); + + case kParseErrorStringUnicodeEscapeInvalidHex: return RAPIDJSON_ERROR_STRING("Incorrect hex digit after \\u escape in string."); + case kParseErrorStringUnicodeSurrogateInvalid: return RAPIDJSON_ERROR_STRING("The surrogate pair in string is invalid."); + case kParseErrorStringEscapeInvalid: return RAPIDJSON_ERROR_STRING("Invalid escape character in string."); + case kParseErrorStringMissQuotationMark: return RAPIDJSON_ERROR_STRING("Missing a closing quotation mark in string."); + case kParseErrorStringInvalidEncoding: return RAPIDJSON_ERROR_STRING("Invalid encoding in string."); + + case kParseErrorNumberTooBig: return RAPIDJSON_ERROR_STRING("Number too big to be stored in double."); + case kParseErrorNumberMissFraction: return RAPIDJSON_ERROR_STRING("Miss fraction part in number."); + case kParseErrorNumberMissExponent: return RAPIDJSON_ERROR_STRING("Miss exponent in number."); + + case kParseErrorTermination: return RAPIDJSON_ERROR_STRING("Terminate parsing due to Handler error."); + case kParseErrorUnspecificSyntaxError: return RAPIDJSON_ERROR_STRING("Unspecific syntax error."); + + default: return RAPIDJSON_ERROR_STRING("Unknown error."); + } +} + +//! Maps error code of validation into error message. +/*! + \ingroup RAPIDJSON_ERRORS + \param validateErrorCode Error code obtained from validator. + \return the error message. + \note User can make a copy of this function for localization. + Using switch-case is safer for future modification of error codes. +*/ +inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode validateErrorCode) { + switch (validateErrorCode) { + case kValidateErrors: return RAPIDJSON_ERROR_STRING("One or more validation errors have occurred"); + case kValidateErrorNone: return RAPIDJSON_ERROR_STRING("No error."); + + case kValidateErrorMultipleOf: return RAPIDJSON_ERROR_STRING("Number '%actual' is not a multiple of the 'multipleOf' value '%expected'."); + case kValidateErrorMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than the 'maximum' value '%expected'."); + case kValidateErrorExclusiveMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than or equal to the 'exclusiveMaximum' value '%expected'."); + case kValidateErrorMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than the 'minimum' value '%expected'."); + case kValidateErrorExclusiveMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than or equal to the 'exclusiveMinimum' value '%expected'."); + + case kValidateErrorMaxLength: return RAPIDJSON_ERROR_STRING("String '%actual' is longer than the 'maxLength' value '%expected'."); + case kValidateErrorMinLength: return RAPIDJSON_ERROR_STRING("String '%actual' is shorter than the 'minLength' value '%expected'."); + case kValidateErrorPattern: return RAPIDJSON_ERROR_STRING("String '%actual' does not match the 'pattern' regular expression."); + + case kValidateErrorMaxItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is longer than the 'maxItems' value '%expected'."); + case kValidateErrorMinItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is shorter than the 'minItems' value '%expected'."); + case kValidateErrorUniqueItems: return RAPIDJSON_ERROR_STRING("Array has duplicate items at indices '%duplicates' but 'uniqueItems' is true."); + case kValidateErrorAdditionalItems: return RAPIDJSON_ERROR_STRING("Array has an additional item at index '%disallowed' that is not allowed by the schema."); + + case kValidateErrorMaxProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is more than 'maxProperties' value '%expected'."); + case kValidateErrorMinProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is less than 'minProperties' value '%expected'."); + case kValidateErrorRequired: return RAPIDJSON_ERROR_STRING("Object is missing the following members required by the schema: '%missing'."); + case kValidateErrorAdditionalProperties: return RAPIDJSON_ERROR_STRING("Object has an additional member '%disallowed' that is not allowed by the schema."); + case kValidateErrorPatternProperties: return RAPIDJSON_ERROR_STRING("Object has 'patternProperties' that are not allowed by the schema."); + case kValidateErrorDependencies: return RAPIDJSON_ERROR_STRING("Object has missing property or schema dependencies, refer to following errors."); + + case kValidateErrorEnum: return RAPIDJSON_ERROR_STRING("Property has a value that is not one of its allowed enumerated values."); + case kValidateErrorType: return RAPIDJSON_ERROR_STRING("Property has a type '%actual' that is not in the following list: '%expected'."); + + case kValidateErrorOneOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'oneOf', refer to following errors."); + case kValidateErrorOneOfMatch: return RAPIDJSON_ERROR_STRING("Property matched more than one of the sub-schemas specified by 'oneOf', indices '%matches'."); + case kValidateErrorAllOf: return RAPIDJSON_ERROR_STRING("Property did not match all of the sub-schemas specified by 'allOf', refer to following errors."); + case kValidateErrorAnyOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'anyOf', refer to following errors."); + case kValidateErrorNot: return RAPIDJSON_ERROR_STRING("Property matched the sub-schema specified by 'not'."); + + case kValidateErrorReadOnly: return RAPIDJSON_ERROR_STRING("Property is read-only but has been provided when validation is for writing."); + case kValidateErrorWriteOnly: return RAPIDJSON_ERROR_STRING("Property is write-only but has been provided when validation is for reading."); + + default: return RAPIDJSON_ERROR_STRING("Unknown error."); + } +} + +//! Maps error code of schema document compilation into error message. +/*! + \ingroup RAPIDJSON_ERRORS + \param schemaErrorCode Error code obtained from compiling the schema document. + \return the error message. + \note User can make a copy of this function for localization. + Using switch-case is safer for future modification of error codes. +*/ + inline const RAPIDJSON_ERROR_CHARTYPE* GetSchemaError_En(SchemaErrorCode schemaErrorCode) { + switch (schemaErrorCode) { + case kSchemaErrorNone: return RAPIDJSON_ERROR_STRING("No error."); + + case kSchemaErrorStartUnknown: return RAPIDJSON_ERROR_STRING("Pointer '%value' to start of schema does not resolve to a location in the document."); + case kSchemaErrorRefPlainName: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' must be a JSON pointer."); + case kSchemaErrorRefInvalid: return RAPIDJSON_ERROR_STRING("$ref must not be an empty string."); + case kSchemaErrorRefPointerInvalid: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' is not a valid JSON pointer at offset '%offset'."); + case kSchemaErrorRefUnknown: return RAPIDJSON_ERROR_STRING("$ref '%value' does not resolve to a location in the target document."); + case kSchemaErrorRefCyclical: return RAPIDJSON_ERROR_STRING("$ref '%value' is cyclical."); + case kSchemaErrorRefNoRemoteProvider: return RAPIDJSON_ERROR_STRING("$ref is remote but there is no remote provider."); + case kSchemaErrorRefNoRemoteSchema: return RAPIDJSON_ERROR_STRING("$ref '%value' is remote but the remote provider did not return a schema."); + case kSchemaErrorRegexInvalid: return RAPIDJSON_ERROR_STRING("Invalid regular expression '%value' in 'pattern' or 'patternProperties'."); + case kSchemaErrorSpecUnknown: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not recognized."); + case kSchemaErrorSpecUnsupported: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not supported."); + case kSchemaErrorSpecIllegal: return RAPIDJSON_ERROR_STRING("Both JSON schema draft and OpenAPI version found in document."); + case kSchemaErrorReadOnlyAndWriteOnly: return RAPIDJSON_ERROR_STRING("Property must not be both 'readOnly' and 'writeOnly'."); + + default: return RAPIDJSON_ERROR_STRING("Unknown error."); + } + } + +//! Maps error code of pointer parse into error message. +/*! + \ingroup RAPIDJSON_ERRORS + \param pointerParseErrorCode Error code obtained from pointer parse. + \return the error message. + \note User can make a copy of this function for localization. + Using switch-case is safer for future modification of error codes. +*/ +inline const RAPIDJSON_ERROR_CHARTYPE* GetPointerParseError_En(PointerParseErrorCode pointerParseErrorCode) { + switch (pointerParseErrorCode) { + case kPointerParseErrorNone: return RAPIDJSON_ERROR_STRING("No error."); + + case kPointerParseErrorTokenMustBeginWithSolidus: return RAPIDJSON_ERROR_STRING("A token must begin with a '/'."); + case kPointerParseErrorInvalidEscape: return RAPIDJSON_ERROR_STRING("Invalid escape."); + case kPointerParseErrorInvalidPercentEncoding: return RAPIDJSON_ERROR_STRING("Invalid percent encoding in URI fragment."); + case kPointerParseErrorCharacterMustPercentEncode: return RAPIDJSON_ERROR_STRING("A character must be percent encoded in a URI fragment."); + + default: return RAPIDJSON_ERROR_STRING("Unknown error."); + } +} + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_ERROR_EN_H_ diff --git a/include/rapidjson/error/error.h b/include/rapidjson/error/error.h new file mode 100644 index 0000000000..cae345db36 --- /dev/null +++ b/include/rapidjson/error/error.h @@ -0,0 +1,285 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ERROR_ERROR_H_ +#define RAPIDJSON_ERROR_ERROR_H_ + +#include "../rapidjson.h" + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +#endif + +/*! \file error.h */ + +/*! \defgroup RAPIDJSON_ERRORS RapidJSON error handling */ + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_ERROR_CHARTYPE + +//! Character type of error messages. +/*! \ingroup RAPIDJSON_ERRORS + The default character type is \c char. + On Windows, user can define this macro as \c TCHAR for supporting both + unicode/non-unicode settings. +*/ +#ifndef RAPIDJSON_ERROR_CHARTYPE +#define RAPIDJSON_ERROR_CHARTYPE char +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_ERROR_STRING + +//! Macro for converting string literal to \ref RAPIDJSON_ERROR_CHARTYPE[]. +/*! \ingroup RAPIDJSON_ERRORS + By default this conversion macro does nothing. + On Windows, user can define this macro as \c _T(x) for supporting both + unicode/non-unicode settings. +*/ +#ifndef RAPIDJSON_ERROR_STRING +#define RAPIDJSON_ERROR_STRING(x) x +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// ParseErrorCode + +//! Error code of parsing. +/*! \ingroup RAPIDJSON_ERRORS + \see GenericReader::Parse, GenericReader::GetParseErrorCode +*/ +enum ParseErrorCode { + kParseErrorNone = 0, //!< No error. + + kParseErrorDocumentEmpty, //!< The document is empty. + kParseErrorDocumentRootNotSingular, //!< The document root must not follow by other values. + + kParseErrorValueInvalid, //!< Invalid value. + + kParseErrorObjectMissName, //!< Missing a name for object member. + kParseErrorObjectMissColon, //!< Missing a colon after a name of object member. + kParseErrorObjectMissCommaOrCurlyBracket, //!< Missing a comma or '}' after an object member. + + kParseErrorArrayMissCommaOrSquareBracket, //!< Missing a comma or ']' after an array element. + + kParseErrorStringUnicodeEscapeInvalidHex, //!< Incorrect hex digit after \\u escape in string. + kParseErrorStringUnicodeSurrogateInvalid, //!< The surrogate pair in string is invalid. + kParseErrorStringEscapeInvalid, //!< Invalid escape character in string. + kParseErrorStringMissQuotationMark, //!< Missing a closing quotation mark in string. + kParseErrorStringInvalidEncoding, //!< Invalid encoding in string. + + kParseErrorNumberTooBig, //!< Number too big to be stored in double. + kParseErrorNumberMissFraction, //!< Miss fraction part in number. + kParseErrorNumberMissExponent, //!< Miss exponent in number. + + kParseErrorTermination, //!< Parsing was terminated. + kParseErrorUnspecificSyntaxError //!< Unspecific syntax error. +}; + +//! Result of parsing (wraps ParseErrorCode) +/*! + \ingroup RAPIDJSON_ERRORS + \code + Document doc; + ParseResult ok = doc.Parse("[42]"); + if (!ok) { + fprintf(stderr, "JSON parse error: %s (%u)", + GetParseError_En(ok.Code()), ok.Offset()); + exit(EXIT_FAILURE); + } + \endcode + \see GenericReader::Parse, GenericDocument::Parse +*/ +struct ParseResult { + //!! Unspecified boolean type + typedef bool (ParseResult::*BooleanType)() const; +public: + //! Default constructor, no error. + ParseResult() : code_(kParseErrorNone), offset_(0) {} + //! Constructor to set an error. + ParseResult(ParseErrorCode code, size_t offset) : code_(code), offset_(offset) {} + + //! Get the error code. + ParseErrorCode Code() const { return code_; } + //! Get the error offset, if \ref IsError(), 0 otherwise. + size_t Offset() const { return offset_; } + + //! Explicit conversion to \c bool, returns \c true, iff !\ref IsError(). + operator BooleanType() const { return !IsError() ? &ParseResult::IsError : NULL; } + //! Whether the result is an error. + bool IsError() const { return code_ != kParseErrorNone; } + + bool operator==(const ParseResult& that) const { return code_ == that.code_; } + bool operator==(ParseErrorCode code) const { return code_ == code; } + friend bool operator==(ParseErrorCode code, const ParseResult & err) { return code == err.code_; } + + bool operator!=(const ParseResult& that) const { return !(*this == that); } + bool operator!=(ParseErrorCode code) const { return !(*this == code); } + friend bool operator!=(ParseErrorCode code, const ParseResult & err) { return err != code; } + + //! Reset error code. + void Clear() { Set(kParseErrorNone); } + //! Update error code and offset. + void Set(ParseErrorCode code, size_t offset = 0) { code_ = code; offset_ = offset; } + +private: + ParseErrorCode code_; + size_t offset_; +}; + +//! Function pointer type of GetParseError(). +/*! \ingroup RAPIDJSON_ERRORS + + This is the prototype for \c GetParseError_X(), where \c X is a locale. + User can dynamically change locale in runtime, e.g.: +\code + GetParseErrorFunc GetParseError = GetParseError_En; // or whatever + const RAPIDJSON_ERROR_CHARTYPE* s = GetParseError(document.GetParseErrorCode()); +\endcode +*/ +typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetParseErrorFunc)(ParseErrorCode); + +/////////////////////////////////////////////////////////////////////////////// +// ValidateErrorCode + +//! Error codes when validating. +/*! \ingroup RAPIDJSON_ERRORS + \see GenericSchemaValidator +*/ +enum ValidateErrorCode { + kValidateErrors = -1, //!< Top level error code when kValidateContinueOnErrorsFlag set. + kValidateErrorNone = 0, //!< No error. + + kValidateErrorMultipleOf, //!< Number is not a multiple of the 'multipleOf' value. + kValidateErrorMaximum, //!< Number is greater than the 'maximum' value. + kValidateErrorExclusiveMaximum, //!< Number is greater than or equal to the 'maximum' value. + kValidateErrorMinimum, //!< Number is less than the 'minimum' value. + kValidateErrorExclusiveMinimum, //!< Number is less than or equal to the 'minimum' value. + + kValidateErrorMaxLength, //!< String is longer than the 'maxLength' value. + kValidateErrorMinLength, //!< String is longer than the 'maxLength' value. + kValidateErrorPattern, //!< String does not match the 'pattern' regular expression. + + kValidateErrorMaxItems, //!< Array is longer than the 'maxItems' value. + kValidateErrorMinItems, //!< Array is shorter than the 'minItems' value. + kValidateErrorUniqueItems, //!< Array has duplicate items but 'uniqueItems' is true. + kValidateErrorAdditionalItems, //!< Array has additional items that are not allowed by the schema. + + kValidateErrorMaxProperties, //!< Object has more members than 'maxProperties' value. + kValidateErrorMinProperties, //!< Object has less members than 'minProperties' value. + kValidateErrorRequired, //!< Object is missing one or more members required by the schema. + kValidateErrorAdditionalProperties, //!< Object has additional members that are not allowed by the schema. + kValidateErrorPatternProperties, //!< See other errors. + kValidateErrorDependencies, //!< Object has missing property or schema dependencies. + + kValidateErrorEnum, //!< Property has a value that is not one of its allowed enumerated values. + kValidateErrorType, //!< Property has a type that is not allowed by the schema. + + kValidateErrorOneOf, //!< Property did not match any of the sub-schemas specified by 'oneOf'. + kValidateErrorOneOfMatch, //!< Property matched more than one of the sub-schemas specified by 'oneOf'. + kValidateErrorAllOf, //!< Property did not match all of the sub-schemas specified by 'allOf'. + kValidateErrorAnyOf, //!< Property did not match any of the sub-schemas specified by 'anyOf'. + kValidateErrorNot, //!< Property matched the sub-schema specified by 'not'. + + kValidateErrorReadOnly, //!< Property is read-only but has been provided when validation is for writing + kValidateErrorWriteOnly //!< Property is write-only but has been provided when validation is for reading +}; + +//! Function pointer type of GetValidateError(). +/*! \ingroup RAPIDJSON_ERRORS + + This is the prototype for \c GetValidateError_X(), where \c X is a locale. + User can dynamically change locale in runtime, e.g.: +\code + GetValidateErrorFunc GetValidateError = GetValidateError_En; // or whatever + const RAPIDJSON_ERROR_CHARTYPE* s = GetValidateError(validator.GetInvalidSchemaCode()); +\endcode +*/ +typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetValidateErrorFunc)(ValidateErrorCode); + +/////////////////////////////////////////////////////////////////////////////// +// SchemaErrorCode + +//! Error codes when validating. +/*! \ingroup RAPIDJSON_ERRORS + \see GenericSchemaValidator +*/ +enum SchemaErrorCode { + kSchemaErrorNone = 0, //!< No error. + + kSchemaErrorStartUnknown, //!< Pointer to start of schema does not resolve to a location in the document + kSchemaErrorRefPlainName, //!< $ref fragment must be a JSON pointer + kSchemaErrorRefInvalid, //!< $ref must not be an empty string + kSchemaErrorRefPointerInvalid, //!< $ref fragment is not a valid JSON pointer at offset + kSchemaErrorRefUnknown, //!< $ref does not resolve to a location in the target document + kSchemaErrorRefCyclical, //!< $ref is cyclical + kSchemaErrorRefNoRemoteProvider, //!< $ref is remote but there is no remote provider + kSchemaErrorRefNoRemoteSchema, //!< $ref is remote but the remote provider did not return a schema + kSchemaErrorRegexInvalid, //!< Invalid regular expression in 'pattern' or 'patternProperties' + kSchemaErrorSpecUnknown, //!< JSON schema draft or OpenAPI version is not recognized + kSchemaErrorSpecUnsupported, //!< JSON schema draft or OpenAPI version is not supported + kSchemaErrorSpecIllegal, //!< Both JSON schema draft and OpenAPI version found in document + kSchemaErrorReadOnlyAndWriteOnly //!< Property must not be both 'readOnly' and 'writeOnly' +}; + +//! Function pointer type of GetSchemaError(). +/*! \ingroup RAPIDJSON_ERRORS + + This is the prototype for \c GetSchemaError_X(), where \c X is a locale. + User can dynamically change locale in runtime, e.g.: +\code + GetSchemaErrorFunc GetSchemaError = GetSchemaError_En; // or whatever + const RAPIDJSON_ERROR_CHARTYPE* s = GetSchemaError(validator.GetInvalidSchemaCode()); +\endcode +*/ +typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetSchemaErrorFunc)(SchemaErrorCode); + +/////////////////////////////////////////////////////////////////////////////// +// PointerParseErrorCode + +//! Error code of JSON pointer parsing. +/*! \ingroup RAPIDJSON_ERRORS + \see GenericPointer::GenericPointer, GenericPointer::GetParseErrorCode +*/ +enum PointerParseErrorCode { + kPointerParseErrorNone = 0, //!< The parse is successful + + kPointerParseErrorTokenMustBeginWithSolidus, //!< A token must begin with a '/' + kPointerParseErrorInvalidEscape, //!< Invalid escape + kPointerParseErrorInvalidPercentEncoding, //!< Invalid percent encoding in URI fragment + kPointerParseErrorCharacterMustPercentEncode //!< A character must percent encoded in URI fragment +}; + +//! Function pointer type of GetPointerParseError(). +/*! \ingroup RAPIDJSON_ERRORS + + This is the prototype for \c GetPointerParseError_X(), where \c X is a locale. + User can dynamically change locale in runtime, e.g.: +\code + GetPointerParseErrorFunc GetPointerParseError = GetPointerParseError_En; // or whatever + const RAPIDJSON_ERROR_CHARTYPE* s = GetPointerParseError(pointer.GetParseErrorCode()); +\endcode +*/ +typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetPointerParseErrorFunc)(PointerParseErrorCode); + + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_ERROR_ERROR_H_ diff --git a/include/rapidjson/filereadstream.h b/include/rapidjson/filereadstream.h new file mode 100644 index 0000000000..f8bb43cb0c --- /dev/null +++ b/include/rapidjson/filereadstream.h @@ -0,0 +1,99 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_FILEREADSTREAM_H_ +#define RAPIDJSON_FILEREADSTREAM_H_ + +#include "stream.h" +#include + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +RAPIDJSON_DIAG_OFF(unreachable-code) +RAPIDJSON_DIAG_OFF(missing-noreturn) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! File byte stream for input using fread(). +/*! + \note implements Stream concept +*/ +class FileReadStream { +public: + typedef char Ch; //!< Character type (byte). + + //! Constructor. + /*! + \param fp File pointer opened for read. + \param buffer user-supplied buffer. + \param bufferSize size of buffer in bytes. Must >=4 bytes. + */ + FileReadStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) { + RAPIDJSON_ASSERT(fp_ != 0); + RAPIDJSON_ASSERT(bufferSize >= 4); + Read(); + } + + Ch Peek() const { return *current_; } + Ch Take() { Ch c = *current_; Read(); return c; } + size_t Tell() const { return count_ + static_cast(current_ - buffer_); } + + // Not implemented + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + + // For encoding detection only. + const Ch* Peek4() const { + return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; + } + +private: + void Read() { + if (current_ < bufferLast_) + ++current_; + else if (!eof_) { + count_ += readCount_; + readCount_ = std::fread(buffer_, 1, bufferSize_, fp_); + bufferLast_ = buffer_ + readCount_ - 1; + current_ = buffer_; + + if (readCount_ < bufferSize_) { + buffer_[readCount_] = '\0'; + ++bufferLast_; + eof_ = true; + } + } + } + + std::FILE* fp_; + Ch *buffer_; + size_t bufferSize_; + Ch *bufferLast_; + Ch *current_; + size_t readCount_; + size_t count_; //!< Number of characters read + bool eof_; +}; + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_FILESTREAM_H_ diff --git a/include/rapidjson/filewritestream.h b/include/rapidjson/filewritestream.h new file mode 100644 index 0000000000..5d89588c21 --- /dev/null +++ b/include/rapidjson/filewritestream.h @@ -0,0 +1,104 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_FILEWRITESTREAM_H_ +#define RAPIDJSON_FILEWRITESTREAM_H_ + +#include "stream.h" +#include + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(unreachable-code) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Wrapper of C file stream for output using fwrite(). +/*! + \note implements Stream concept +*/ +class FileWriteStream { +public: + typedef char Ch; //!< Character type. Only support char. + + FileWriteStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferEnd_(buffer + bufferSize), current_(buffer_) { + RAPIDJSON_ASSERT(fp_ != 0); + } + + void Put(char c) { + if (current_ >= bufferEnd_) + Flush(); + + *current_++ = c; + } + + void PutN(char c, size_t n) { + size_t avail = static_cast(bufferEnd_ - current_); + while (n > avail) { + std::memset(current_, c, avail); + current_ += avail; + Flush(); + n -= avail; + avail = static_cast(bufferEnd_ - current_); + } + + if (n > 0) { + std::memset(current_, c, n); + current_ += n; + } + } + + void Flush() { + if (current_ != buffer_) { + size_t result = std::fwrite(buffer_, 1, static_cast(current_ - buffer_), fp_); + if (result < static_cast(current_ - buffer_)) { + // failure deliberately ignored at this time + // added to avoid warn_unused_result build errors + } + current_ = buffer_; + } + } + + // Not implemented + char Peek() const { RAPIDJSON_ASSERT(false); return 0; } + char Take() { RAPIDJSON_ASSERT(false); return 0; } + size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; } + char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + // Prohibit copy constructor & assignment operator. + FileWriteStream(const FileWriteStream&); + FileWriteStream& operator=(const FileWriteStream&); + + std::FILE* fp_; + char *buffer_; + char *bufferEnd_; + char *current_; +}; + +//! Implement specialized version of PutN() with memset() for better performance. +template<> +inline void PutN(FileWriteStream& stream, char c, size_t n) { + stream.PutN(c, n); +} + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_FILESTREAM_H_ diff --git a/include/rapidjson/fwd.h b/include/rapidjson/fwd.h new file mode 100644 index 0000000000..d62f77f0ec --- /dev/null +++ b/include/rapidjson/fwd.h @@ -0,0 +1,151 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_FWD_H_ +#define RAPIDJSON_FWD_H_ + +#include "rapidjson.h" + +RAPIDJSON_NAMESPACE_BEGIN + +// encodings.h + +template struct UTF8; +template struct UTF16; +template struct UTF16BE; +template struct UTF16LE; +template struct UTF32; +template struct UTF32BE; +template struct UTF32LE; +template struct ASCII; +template struct AutoUTF; + +template +struct Transcoder; + +// allocators.h + +class CrtAllocator; + +template +class MemoryPoolAllocator; + +// stream.h + +template +struct GenericStringStream; + +typedef GenericStringStream > StringStream; + +template +struct GenericInsituStringStream; + +typedef GenericInsituStringStream > InsituStringStream; + +// stringbuffer.h + +template +class GenericStringBuffer; + +typedef GenericStringBuffer, CrtAllocator> StringBuffer; + +// filereadstream.h + +class FileReadStream; + +// filewritestream.h + +class FileWriteStream; + +// memorybuffer.h + +template +struct GenericMemoryBuffer; + +typedef GenericMemoryBuffer MemoryBuffer; + +// memorystream.h + +struct MemoryStream; + +// reader.h + +template +struct BaseReaderHandler; + +template +class GenericReader; + +typedef GenericReader, UTF8, CrtAllocator> Reader; + +// writer.h + +template +class Writer; + +// prettywriter.h + +template +class PrettyWriter; + +// document.h + +template +class GenericMember; + +template +class GenericMemberIterator; + +template +struct GenericStringRef; + +template +class GenericValue; + +typedef GenericValue, MemoryPoolAllocator > Value; + +template +class GenericDocument; + +typedef GenericDocument, MemoryPoolAllocator, CrtAllocator> Document; + +// pointer.h + +template +class GenericPointer; + +typedef GenericPointer Pointer; + +// schema.h + +template +class IGenericRemoteSchemaDocumentProvider; + +template +class GenericSchemaDocument; + +typedef GenericSchemaDocument SchemaDocument; +typedef IGenericRemoteSchemaDocumentProvider IRemoteSchemaDocumentProvider; + +template < + typename SchemaDocumentType, + typename OutputHandler, + typename StateAllocator> +class GenericSchemaValidator; + +typedef GenericSchemaValidator, void>, CrtAllocator> SchemaValidator; + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_RAPIDJSONFWD_H_ diff --git a/include/rapidjson/internal/biginteger.h b/include/rapidjson/internal/biginteger.h new file mode 100644 index 0000000000..4930043dc7 --- /dev/null +++ b/include/rapidjson/internal/biginteger.h @@ -0,0 +1,297 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_BIGINTEGER_H_ +#define RAPIDJSON_BIGINTEGER_H_ + +#include "../rapidjson.h" + +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) && defined(_M_AMD64) +#include // for _umul128 +#if !defined(_ARM64EC_) +#pragma intrinsic(_umul128) +#else +#pragma comment(lib,"softintrin") +#endif +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +class BigInteger { +public: + typedef uint64_t Type; + + BigInteger(const BigInteger& rhs) : count_(rhs.count_) { + std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type)); + } + + explicit BigInteger(uint64_t u) : count_(1) { + digits_[0] = u; + } + + template + BigInteger(const Ch* decimals, size_t length) : count_(1) { + RAPIDJSON_ASSERT(length > 0); + digits_[0] = 0; + size_t i = 0; + const size_t kMaxDigitPerIteration = 19; // 2^64 = 18446744073709551616 > 10^19 + while (length >= kMaxDigitPerIteration) { + AppendDecimal64(decimals + i, decimals + i + kMaxDigitPerIteration); + length -= kMaxDigitPerIteration; + i += kMaxDigitPerIteration; + } + + if (length > 0) + AppendDecimal64(decimals + i, decimals + i + length); + } + + BigInteger& operator=(const BigInteger &rhs) + { + if (this != &rhs) { + count_ = rhs.count_; + std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type)); + } + return *this; + } + + BigInteger& operator=(uint64_t u) { + digits_[0] = u; + count_ = 1; + return *this; + } + + BigInteger& operator+=(uint64_t u) { + Type backup = digits_[0]; + digits_[0] += u; + for (size_t i = 0; i < count_ - 1; i++) { + if (digits_[i] >= backup) + return *this; // no carry + backup = digits_[i + 1]; + digits_[i + 1] += 1; + } + + // Last carry + if (digits_[count_ - 1] < backup) + PushBack(1); + + return *this; + } + + BigInteger& operator*=(uint64_t u) { + if (u == 0) return *this = 0; + if (u == 1) return *this; + if (*this == 1) return *this = u; + + uint64_t k = 0; + for (size_t i = 0; i < count_; i++) { + uint64_t hi; + digits_[i] = MulAdd64(digits_[i], u, k, &hi); + k = hi; + } + + if (k > 0) + PushBack(k); + + return *this; + } + + BigInteger& operator*=(uint32_t u) { + if (u == 0) return *this = 0; + if (u == 1) return *this; + if (*this == 1) return *this = u; + + uint64_t k = 0; + for (size_t i = 0; i < count_; i++) { + const uint64_t c = digits_[i] >> 32; + const uint64_t d = digits_[i] & 0xFFFFFFFF; + const uint64_t uc = u * c; + const uint64_t ud = u * d; + const uint64_t p0 = ud + k; + const uint64_t p1 = uc + (p0 >> 32); + digits_[i] = (p0 & 0xFFFFFFFF) | (p1 << 32); + k = p1 >> 32; + } + + if (k > 0) + PushBack(k); + + return *this; + } + + BigInteger& operator<<=(size_t shift) { + if (IsZero() || shift == 0) return *this; + + size_t offset = shift / kTypeBit; + size_t interShift = shift % kTypeBit; + RAPIDJSON_ASSERT(count_ + offset <= kCapacity); + + if (interShift == 0) { + std::memmove(digits_ + offset, digits_, count_ * sizeof(Type)); + count_ += offset; + } + else { + digits_[count_] = 0; + for (size_t i = count_; i > 0; i--) + digits_[i + offset] = (digits_[i] << interShift) | (digits_[i - 1] >> (kTypeBit - interShift)); + digits_[offset] = digits_[0] << interShift; + count_ += offset; + if (digits_[count_]) + count_++; + } + + std::memset(digits_, 0, offset * sizeof(Type)); + + return *this; + } + + bool operator==(const BigInteger& rhs) const { + return count_ == rhs.count_ && std::memcmp(digits_, rhs.digits_, count_ * sizeof(Type)) == 0; + } + + bool operator==(const Type rhs) const { + return count_ == 1 && digits_[0] == rhs; + } + + BigInteger& MultiplyPow5(unsigned exp) { + static const uint32_t kPow5[12] = { + 5, + 5 * 5, + 5 * 5 * 5, + 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5, + 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 + }; + if (exp == 0) return *this; + for (; exp >= 27; exp -= 27) *this *= RAPIDJSON_UINT64_C2(0X6765C793, 0XFA10079D); // 5^27 + for (; exp >= 13; exp -= 13) *this *= static_cast(1220703125u); // 5^13 + if (exp > 0) *this *= kPow5[exp - 1]; + return *this; + } + + // Compute absolute difference of this and rhs. + // Assume this != rhs + bool Difference(const BigInteger& rhs, BigInteger* out) const { + int cmp = Compare(rhs); + RAPIDJSON_ASSERT(cmp != 0); + const BigInteger *a, *b; // Makes a > b + bool ret; + if (cmp < 0) { a = &rhs; b = this; ret = true; } + else { a = this; b = &rhs; ret = false; } + + Type borrow = 0; + for (size_t i = 0; i < a->count_; i++) { + Type d = a->digits_[i] - borrow; + if (i < b->count_) + d -= b->digits_[i]; + borrow = (d > a->digits_[i]) ? 1 : 0; + out->digits_[i] = d; + if (d != 0) + out->count_ = i + 1; + } + + return ret; + } + + int Compare(const BigInteger& rhs) const { + if (count_ != rhs.count_) + return count_ < rhs.count_ ? -1 : 1; + + for (size_t i = count_; i-- > 0;) + if (digits_[i] != rhs.digits_[i]) + return digits_[i] < rhs.digits_[i] ? -1 : 1; + + return 0; + } + + size_t GetCount() const { return count_; } + Type GetDigit(size_t index) const { RAPIDJSON_ASSERT(index < count_); return digits_[index]; } + bool IsZero() const { return count_ == 1 && digits_[0] == 0; } + +private: + template + void AppendDecimal64(const Ch* begin, const Ch* end) { + uint64_t u = ParseUint64(begin, end); + if (IsZero()) + *this = u; + else { + unsigned exp = static_cast(end - begin); + (MultiplyPow5(exp) <<= exp) += u; // *this = *this * 10^exp + u + } + } + + void PushBack(Type digit) { + RAPIDJSON_ASSERT(count_ < kCapacity); + digits_[count_++] = digit; + } + + template + static uint64_t ParseUint64(const Ch* begin, const Ch* end) { + uint64_t r = 0; + for (const Ch* p = begin; p != end; ++p) { + RAPIDJSON_ASSERT(*p >= Ch('0') && *p <= Ch('9')); + r = r * 10u + static_cast(*p - Ch('0')); + } + return r; + } + + // Assume a * b + k < 2^128 + static uint64_t MulAdd64(uint64_t a, uint64_t b, uint64_t k, uint64_t* outHigh) { +#if defined(_MSC_VER) && defined(_M_AMD64) + uint64_t low = _umul128(a, b, outHigh) + k; + if (low < k) + (*outHigh)++; + return low; +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__) + __extension__ typedef unsigned __int128 uint128; + uint128 p = static_cast(a) * static_cast(b); + p += k; + *outHigh = static_cast(p >> 64); + return static_cast(p); +#else + const uint64_t a0 = a & 0xFFFFFFFF, a1 = a >> 32, b0 = b & 0xFFFFFFFF, b1 = b >> 32; + uint64_t x0 = a0 * b0, x1 = a0 * b1, x2 = a1 * b0, x3 = a1 * b1; + x1 += (x0 >> 32); // can't give carry + x1 += x2; + if (x1 < x2) + x3 += (static_cast(1) << 32); + uint64_t lo = (x1 << 32) + (x0 & 0xFFFFFFFF); + uint64_t hi = x3 + (x1 >> 32); + + lo += k; + if (lo < k) + hi++; + *outHigh = hi; + return lo; +#endif + } + + static const size_t kBitCount = 3328; // 64bit * 54 > 10^1000 + static const size_t kCapacity = kBitCount / sizeof(Type); + static const size_t kTypeBit = sizeof(Type) * 8; + + Type digits_[kCapacity]; + size_t count_; +}; + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_BIGINTEGER_H_ diff --git a/include/rapidjson/internal/clzll.h b/include/rapidjson/internal/clzll.h new file mode 100644 index 0000000000..8fc5118aa4 --- /dev/null +++ b/include/rapidjson/internal/clzll.h @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_CLZLL_H_ +#define RAPIDJSON_CLZLL_H_ + +#include "../rapidjson.h" + +#if defined(_MSC_VER) && !defined(UNDER_CE) +#include +#if defined(_WIN64) +#pragma intrinsic(_BitScanReverse64) +#else +#pragma intrinsic(_BitScanReverse) +#endif +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +inline uint32_t clzll(uint64_t x) { + // Passing 0 to __builtin_clzll is UB in GCC and results in an + // infinite loop in the software implementation. + RAPIDJSON_ASSERT(x != 0); + +#if defined(_MSC_VER) && !defined(UNDER_CE) + unsigned long r = 0; +#if defined(_WIN64) + _BitScanReverse64(&r, x); +#else + // Scan the high 32 bits. + if (_BitScanReverse(&r, static_cast(x >> 32))) + return 63 - (r + 32); + + // Scan the low 32 bits. + _BitScanReverse(&r, static_cast(x & 0xFFFFFFFF)); +#endif // _WIN64 + + return 63 - r; +#elif (defined(__GNUC__) && __GNUC__ >= 4) || RAPIDJSON_HAS_BUILTIN(__builtin_clzll) + // __builtin_clzll wrapper + return static_cast(__builtin_clzll(x)); +#else + // naive version + uint32_t r = 0; + while (!(x & (static_cast(1) << 63))) { + x <<= 1; + ++r; + } + + return r; +#endif // _MSC_VER +} + +#define RAPIDJSON_CLZLL RAPIDJSON_NAMESPACE::internal::clzll + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_CLZLL_H_ diff --git a/include/rapidjson/internal/diyfp.h b/include/rapidjson/internal/diyfp.h new file mode 100644 index 0000000000..1f60fb60ca --- /dev/null +++ b/include/rapidjson/internal/diyfp.h @@ -0,0 +1,261 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// This is a C++ header-only implementation of Grisu2 algorithm from the publication: +// Loitsch, Florian. "Printing floating-point numbers quickly and accurately with +// integers." ACM Sigplan Notices 45.6 (2010): 233-243. + +#ifndef RAPIDJSON_DIYFP_H_ +#define RAPIDJSON_DIYFP_H_ + +#include "../rapidjson.h" +#include "clzll.h" +#include + +#if defined(_MSC_VER) && defined(_M_AMD64) && !defined(__INTEL_COMPILER) +#include +#if !defined(_ARM64EC_) +#pragma intrinsic(_umul128) +#else +#pragma comment(lib,"softintrin") +#endif +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +#endif + +struct DiyFp { + DiyFp() : f(), e() {} + + DiyFp(uint64_t fp, int exp) : f(fp), e(exp) {} + + explicit DiyFp(double d) { + union { + double d; + uint64_t u64; + } u = { d }; + + int biased_e = static_cast((u.u64 & kDpExponentMask) >> kDpSignificandSize); + uint64_t significand = (u.u64 & kDpSignificandMask); + if (biased_e != 0) { + f = significand + kDpHiddenBit; + e = biased_e - kDpExponentBias; + } + else { + f = significand; + e = kDpMinExponent + 1; + } + } + + DiyFp operator-(const DiyFp& rhs) const { + return DiyFp(f - rhs.f, e); + } + + DiyFp operator*(const DiyFp& rhs) const { +#if defined(_MSC_VER) && defined(_M_AMD64) + uint64_t h; + uint64_t l = _umul128(f, rhs.f, &h); + if (l & (uint64_t(1) << 63)) // rounding + h++; + return DiyFp(h, e + rhs.e + 64); +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__) + __extension__ typedef unsigned __int128 uint128; + uint128 p = static_cast(f) * static_cast(rhs.f); + uint64_t h = static_cast(p >> 64); + uint64_t l = static_cast(p); + if (l & (uint64_t(1) << 63)) // rounding + h++; + return DiyFp(h, e + rhs.e + 64); +#else + const uint64_t M32 = 0xFFFFFFFF; + const uint64_t a = f >> 32; + const uint64_t b = f & M32; + const uint64_t c = rhs.f >> 32; + const uint64_t d = rhs.f & M32; + const uint64_t ac = a * c; + const uint64_t bc = b * c; + const uint64_t ad = a * d; + const uint64_t bd = b * d; + uint64_t tmp = (bd >> 32) + (ad & M32) + (bc & M32); + tmp += 1U << 31; /// mult_round + return DiyFp(ac + (ad >> 32) + (bc >> 32) + (tmp >> 32), e + rhs.e + 64); +#endif + } + + DiyFp Normalize() const { + int s = static_cast(clzll(f)); + return DiyFp(f << s, e - s); + } + + DiyFp NormalizeBoundary() const { + DiyFp res = *this; + while (!(res.f & (kDpHiddenBit << 1))) { + res.f <<= 1; + res.e--; + } + res.f <<= (kDiySignificandSize - kDpSignificandSize - 2); + res.e = res.e - (kDiySignificandSize - kDpSignificandSize - 2); + return res; + } + + void NormalizedBoundaries(DiyFp* minus, DiyFp* plus) const { + DiyFp pl = DiyFp((f << 1) + 1, e - 1).NormalizeBoundary(); + DiyFp mi = (f == kDpHiddenBit) ? DiyFp((f << 2) - 1, e - 2) : DiyFp((f << 1) - 1, e - 1); + mi.f <<= mi.e - pl.e; + mi.e = pl.e; + *plus = pl; + *minus = mi; + } + + double ToDouble() const { + union { + double d; + uint64_t u64; + }u; + RAPIDJSON_ASSERT(f <= kDpHiddenBit + kDpSignificandMask); + if (e < kDpDenormalExponent) { + // Underflow. + return 0.0; + } + if (e >= kDpMaxExponent) { + // Overflow. + return std::numeric_limits::infinity(); + } + const uint64_t be = (e == kDpDenormalExponent && (f & kDpHiddenBit) == 0) ? 0 : + static_cast(e + kDpExponentBias); + u.u64 = (f & kDpSignificandMask) | (be << kDpSignificandSize); + return u.d; + } + + static const int kDiySignificandSize = 64; + static const int kDpSignificandSize = 52; + static const int kDpExponentBias = 0x3FF + kDpSignificandSize; + static const int kDpMaxExponent = 0x7FF - kDpExponentBias; + static const int kDpMinExponent = -kDpExponentBias; + static const int kDpDenormalExponent = -kDpExponentBias + 1; + static const uint64_t kDpExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000); + static const uint64_t kDpSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF); + static const uint64_t kDpHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000); + + uint64_t f; + int e; +}; + +inline DiyFp GetCachedPowerByIndex(size_t index) { + // 10^-348, 10^-340, ..., 10^340 + static const uint64_t kCachedPowers_F[] = { + RAPIDJSON_UINT64_C2(0xfa8fd5a0, 0x081c0288), RAPIDJSON_UINT64_C2(0xbaaee17f, 0xa23ebf76), + RAPIDJSON_UINT64_C2(0x8b16fb20, 0x3055ac76), RAPIDJSON_UINT64_C2(0xcf42894a, 0x5dce35ea), + RAPIDJSON_UINT64_C2(0x9a6bb0aa, 0x55653b2d), RAPIDJSON_UINT64_C2(0xe61acf03, 0x3d1a45df), + RAPIDJSON_UINT64_C2(0xab70fe17, 0xc79ac6ca), RAPIDJSON_UINT64_C2(0xff77b1fc, 0xbebcdc4f), + RAPIDJSON_UINT64_C2(0xbe5691ef, 0x416bd60c), RAPIDJSON_UINT64_C2(0x8dd01fad, 0x907ffc3c), + RAPIDJSON_UINT64_C2(0xd3515c28, 0x31559a83), RAPIDJSON_UINT64_C2(0x9d71ac8f, 0xada6c9b5), + RAPIDJSON_UINT64_C2(0xea9c2277, 0x23ee8bcb), RAPIDJSON_UINT64_C2(0xaecc4991, 0x4078536d), + RAPIDJSON_UINT64_C2(0x823c1279, 0x5db6ce57), RAPIDJSON_UINT64_C2(0xc2109436, 0x4dfb5637), + RAPIDJSON_UINT64_C2(0x9096ea6f, 0x3848984f), RAPIDJSON_UINT64_C2(0xd77485cb, 0x25823ac7), + RAPIDJSON_UINT64_C2(0xa086cfcd, 0x97bf97f4), RAPIDJSON_UINT64_C2(0xef340a98, 0x172aace5), + RAPIDJSON_UINT64_C2(0xb23867fb, 0x2a35b28e), RAPIDJSON_UINT64_C2(0x84c8d4df, 0xd2c63f3b), + RAPIDJSON_UINT64_C2(0xc5dd4427, 0x1ad3cdba), RAPIDJSON_UINT64_C2(0x936b9fce, 0xbb25c996), + RAPIDJSON_UINT64_C2(0xdbac6c24, 0x7d62a584), RAPIDJSON_UINT64_C2(0xa3ab6658, 0x0d5fdaf6), + RAPIDJSON_UINT64_C2(0xf3e2f893, 0xdec3f126), RAPIDJSON_UINT64_C2(0xb5b5ada8, 0xaaff80b8), + RAPIDJSON_UINT64_C2(0x87625f05, 0x6c7c4a8b), RAPIDJSON_UINT64_C2(0xc9bcff60, 0x34c13053), + RAPIDJSON_UINT64_C2(0x964e858c, 0x91ba2655), RAPIDJSON_UINT64_C2(0xdff97724, 0x70297ebd), + RAPIDJSON_UINT64_C2(0xa6dfbd9f, 0xb8e5b88f), RAPIDJSON_UINT64_C2(0xf8a95fcf, 0x88747d94), + RAPIDJSON_UINT64_C2(0xb9447093, 0x8fa89bcf), RAPIDJSON_UINT64_C2(0x8a08f0f8, 0xbf0f156b), + RAPIDJSON_UINT64_C2(0xcdb02555, 0x653131b6), RAPIDJSON_UINT64_C2(0x993fe2c6, 0xd07b7fac), + RAPIDJSON_UINT64_C2(0xe45c10c4, 0x2a2b3b06), RAPIDJSON_UINT64_C2(0xaa242499, 0x697392d3), + RAPIDJSON_UINT64_C2(0xfd87b5f2, 0x8300ca0e), RAPIDJSON_UINT64_C2(0xbce50864, 0x92111aeb), + RAPIDJSON_UINT64_C2(0x8cbccc09, 0x6f5088cc), RAPIDJSON_UINT64_C2(0xd1b71758, 0xe219652c), + RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), RAPIDJSON_UINT64_C2(0xe8d4a510, 0x00000000), + RAPIDJSON_UINT64_C2(0xad78ebc5, 0xac620000), RAPIDJSON_UINT64_C2(0x813f3978, 0xf8940984), + RAPIDJSON_UINT64_C2(0xc097ce7b, 0xc90715b3), RAPIDJSON_UINT64_C2(0x8f7e32ce, 0x7bea5c70), + RAPIDJSON_UINT64_C2(0xd5d238a4, 0xabe98068), RAPIDJSON_UINT64_C2(0x9f4f2726, 0x179a2245), + RAPIDJSON_UINT64_C2(0xed63a231, 0xd4c4fb27), RAPIDJSON_UINT64_C2(0xb0de6538, 0x8cc8ada8), + RAPIDJSON_UINT64_C2(0x83c7088e, 0x1aab65db), RAPIDJSON_UINT64_C2(0xc45d1df9, 0x42711d9a), + RAPIDJSON_UINT64_C2(0x924d692c, 0xa61be758), RAPIDJSON_UINT64_C2(0xda01ee64, 0x1a708dea), + RAPIDJSON_UINT64_C2(0xa26da399, 0x9aef774a), RAPIDJSON_UINT64_C2(0xf209787b, 0xb47d6b85), + RAPIDJSON_UINT64_C2(0xb454e4a1, 0x79dd1877), RAPIDJSON_UINT64_C2(0x865b8692, 0x5b9bc5c2), + RAPIDJSON_UINT64_C2(0xc83553c5, 0xc8965d3d), RAPIDJSON_UINT64_C2(0x952ab45c, 0xfa97a0b3), + RAPIDJSON_UINT64_C2(0xde469fbd, 0x99a05fe3), RAPIDJSON_UINT64_C2(0xa59bc234, 0xdb398c25), + RAPIDJSON_UINT64_C2(0xf6c69a72, 0xa3989f5c), RAPIDJSON_UINT64_C2(0xb7dcbf53, 0x54e9bece), + RAPIDJSON_UINT64_C2(0x88fcf317, 0xf22241e2), RAPIDJSON_UINT64_C2(0xcc20ce9b, 0xd35c78a5), + RAPIDJSON_UINT64_C2(0x98165af3, 0x7b2153df), RAPIDJSON_UINT64_C2(0xe2a0b5dc, 0x971f303a), + RAPIDJSON_UINT64_C2(0xa8d9d153, 0x5ce3b396), RAPIDJSON_UINT64_C2(0xfb9b7cd9, 0xa4a7443c), + RAPIDJSON_UINT64_C2(0xbb764c4c, 0xa7a44410), RAPIDJSON_UINT64_C2(0x8bab8eef, 0xb6409c1a), + RAPIDJSON_UINT64_C2(0xd01fef10, 0xa657842c), RAPIDJSON_UINT64_C2(0x9b10a4e5, 0xe9913129), + RAPIDJSON_UINT64_C2(0xe7109bfb, 0xa19c0c9d), RAPIDJSON_UINT64_C2(0xac2820d9, 0x623bf429), + RAPIDJSON_UINT64_C2(0x80444b5e, 0x7aa7cf85), RAPIDJSON_UINT64_C2(0xbf21e440, 0x03acdd2d), + RAPIDJSON_UINT64_C2(0x8e679c2f, 0x5e44ff8f), RAPIDJSON_UINT64_C2(0xd433179d, 0x9c8cb841), + RAPIDJSON_UINT64_C2(0x9e19db92, 0xb4e31ba9), RAPIDJSON_UINT64_C2(0xeb96bf6e, 0xbadf77d9), + RAPIDJSON_UINT64_C2(0xaf87023b, 0x9bf0ee6b) + }; + static const int16_t kCachedPowers_E[] = { + -1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980, + -954, -927, -901, -874, -847, -821, -794, -768, -741, -715, + -688, -661, -635, -608, -582, -555, -529, -502, -475, -449, + -422, -396, -369, -343, -316, -289, -263, -236, -210, -183, + -157, -130, -103, -77, -50, -24, 3, 30, 56, 83, + 109, 136, 162, 189, 216, 242, 269, 295, 322, 348, + 375, 402, 428, 455, 481, 508, 534, 561, 588, 614, + 641, 667, 694, 720, 747, 774, 800, 827, 853, 880, + 907, 933, 960, 986, 1013, 1039, 1066 + }; + RAPIDJSON_ASSERT(index < 87); + return DiyFp(kCachedPowers_F[index], kCachedPowers_E[index]); +} + +inline DiyFp GetCachedPower(int e, int* K) { + + //int k = static_cast(ceil((-61 - e) * 0.30102999566398114)) + 374; + double dk = (-61 - e) * 0.30102999566398114 + 347; // dk must be positive, so can do ceiling in positive + int k = static_cast(dk); + if (dk - k > 0.0) + k++; + + unsigned index = static_cast((k >> 3) + 1); + *K = -(-348 + static_cast(index << 3)); // decimal exponent no need lookup table + + return GetCachedPowerByIndex(index); +} + +inline DiyFp GetCachedPower10(int exp, int *outExp) { + RAPIDJSON_ASSERT(exp >= -348); + unsigned index = static_cast(exp + 348) / 8u; + *outExp = -348 + static_cast(index) * 8; + return GetCachedPowerByIndex(index); +} + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +RAPIDJSON_DIAG_OFF(padded) +#endif + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_DIYFP_H_ diff --git a/include/rapidjson/internal/dtoa.h b/include/rapidjson/internal/dtoa.h new file mode 100644 index 0000000000..cd456721a7 --- /dev/null +++ b/include/rapidjson/internal/dtoa.h @@ -0,0 +1,249 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// This is a C++ header-only implementation of Grisu2 algorithm from the publication: +// Loitsch, Florian. "Printing floating-point numbers quickly and accurately with +// integers." ACM Sigplan Notices 45.6 (2010): 233-243. + +#ifndef RAPIDJSON_DTOA_ +#define RAPIDJSON_DTOA_ + +#include "itoa.h" // GetDigitsLut() +#include "diyfp.h" +#include "ieee754.h" + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +RAPIDJSON_DIAG_OFF(array-bounds) // some gcc versions generate wrong warnings https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59124 +#endif + +inline void GrisuRound(char* buffer, int len, uint64_t delta, uint64_t rest, uint64_t ten_kappa, uint64_t wp_w) { + while (rest < wp_w && delta - rest >= ten_kappa && + (rest + ten_kappa < wp_w || /// closer + wp_w - rest > rest + ten_kappa - wp_w)) { + buffer[len - 1]--; + rest += ten_kappa; + } +} + +inline int CountDecimalDigit32(uint32_t n) { + // Simple pure C++ implementation was faster than __builtin_clz version in this situation. + if (n < 10) return 1; + if (n < 100) return 2; + if (n < 1000) return 3; + if (n < 10000) return 4; + if (n < 100000) return 5; + if (n < 1000000) return 6; + if (n < 10000000) return 7; + if (n < 100000000) return 8; + // Will not reach 10 digits in DigitGen() + //if (n < 1000000000) return 9; + //return 10; + return 9; +} + +inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buffer, int* len, int* K) { + static const uint64_t kPow10[] = { 1ULL, 10ULL, 100ULL, 1000ULL, 10000ULL, 100000ULL, 1000000ULL, 10000000ULL, 100000000ULL, + 1000000000ULL, 10000000000ULL, 100000000000ULL, 1000000000000ULL, + 10000000000000ULL, 100000000000000ULL, 1000000000000000ULL, + 10000000000000000ULL, 100000000000000000ULL, 1000000000000000000ULL, + 10000000000000000000ULL }; + const DiyFp one(uint64_t(1) << -Mp.e, Mp.e); + const DiyFp wp_w = Mp - W; + uint32_t p1 = static_cast(Mp.f >> -one.e); + uint64_t p2 = Mp.f & (one.f - 1); + int kappa = CountDecimalDigit32(p1); // kappa in [0, 9] + *len = 0; + + while (kappa > 0) { + uint32_t d = 0; + switch (kappa) { + case 9: d = p1 / 100000000; p1 %= 100000000; break; + case 8: d = p1 / 10000000; p1 %= 10000000; break; + case 7: d = p1 / 1000000; p1 %= 1000000; break; + case 6: d = p1 / 100000; p1 %= 100000; break; + case 5: d = p1 / 10000; p1 %= 10000; break; + case 4: d = p1 / 1000; p1 %= 1000; break; + case 3: d = p1 / 100; p1 %= 100; break; + case 2: d = p1 / 10; p1 %= 10; break; + case 1: d = p1; p1 = 0; break; + default:; + } + if (d || *len) + buffer[(*len)++] = static_cast('0' + static_cast(d)); + kappa--; + uint64_t tmp = (static_cast(p1) << -one.e) + p2; + if (tmp <= delta) { + *K += kappa; + GrisuRound(buffer, *len, delta, tmp, kPow10[kappa] << -one.e, wp_w.f); + return; + } + } + + // kappa = 0 + for (;;) { + p2 *= 10; + delta *= 10; + char d = static_cast(p2 >> -one.e); + if (d || *len) + buffer[(*len)++] = static_cast('0' + d); + p2 &= one.f - 1; + kappa--; + if (p2 < delta) { + *K += kappa; + int index = -kappa; + GrisuRound(buffer, *len, delta, p2, one.f, wp_w.f * (index < 20 ? kPow10[index] : 0)); + return; + } + } +} + +inline void Grisu2(double value, char* buffer, int* length, int* K) { + const DiyFp v(value); + DiyFp w_m, w_p; + v.NormalizedBoundaries(&w_m, &w_p); + + const DiyFp c_mk = GetCachedPower(w_p.e, K); + const DiyFp W = v.Normalize() * c_mk; + DiyFp Wp = w_p * c_mk; + DiyFp Wm = w_m * c_mk; + Wm.f++; + Wp.f--; + DigitGen(W, Wp, Wp.f - Wm.f, buffer, length, K); +} + +inline char* WriteExponent(int K, char* buffer) { + if (K < 0) { + *buffer++ = '-'; + K = -K; + } + + if (K >= 100) { + *buffer++ = static_cast('0' + static_cast(K / 100)); + K %= 100; + const char* d = GetDigitsLut() + K * 2; + *buffer++ = d[0]; + *buffer++ = d[1]; + } + else if (K >= 10) { + const char* d = GetDigitsLut() + K * 2; + *buffer++ = d[0]; + *buffer++ = d[1]; + } + else + *buffer++ = static_cast('0' + static_cast(K)); + + return buffer; +} + +inline char* Prettify(char* buffer, int length, int k, int maxDecimalPlaces) { + const int kk = length + k; // 10^(kk-1) <= v < 10^kk + + if (0 <= k && kk <= 21) { + // 1234e7 -> 12340000000 + for (int i = length; i < kk; i++) + buffer[i] = '0'; + buffer[kk] = '.'; + buffer[kk + 1] = '0'; + return &buffer[kk + 2]; + } + else if (0 < kk && kk <= 21) { + // 1234e-2 -> 12.34 + std::memmove(&buffer[kk + 1], &buffer[kk], static_cast(length - kk)); + buffer[kk] = '.'; + if (0 > k + maxDecimalPlaces) { + // When maxDecimalPlaces = 2, 1.2345 -> 1.23, 1.102 -> 1.1 + // Remove extra trailing zeros (at least one) after truncation. + for (int i = kk + maxDecimalPlaces; i > kk + 1; i--) + if (buffer[i] != '0') + return &buffer[i + 1]; + return &buffer[kk + 2]; // Reserve one zero + } + else + return &buffer[length + 1]; + } + else if (-6 < kk && kk <= 0) { + // 1234e-6 -> 0.001234 + const int offset = 2 - kk; + std::memmove(&buffer[offset], &buffer[0], static_cast(length)); + buffer[0] = '0'; + buffer[1] = '.'; + for (int i = 2; i < offset; i++) + buffer[i] = '0'; + if (length - kk > maxDecimalPlaces) { + // When maxDecimalPlaces = 2, 0.123 -> 0.12, 0.102 -> 0.1 + // Remove extra trailing zeros (at least one) after truncation. + for (int i = maxDecimalPlaces + 1; i > 2; i--) + if (buffer[i] != '0') + return &buffer[i + 1]; + return &buffer[3]; // Reserve one zero + } + else + return &buffer[length + offset]; + } + else if (kk < -maxDecimalPlaces) { + // Truncate to zero + buffer[0] = '0'; + buffer[1] = '.'; + buffer[2] = '0'; + return &buffer[3]; + } + else if (length == 1) { + // 1e30 + buffer[1] = 'e'; + return WriteExponent(kk - 1, &buffer[2]); + } + else { + // 1234e30 -> 1.234e33 + std::memmove(&buffer[2], &buffer[1], static_cast(length - 1)); + buffer[1] = '.'; + buffer[length + 1] = 'e'; + return WriteExponent(kk - 1, &buffer[0 + length + 2]); + } +} + +inline char* dtoa(double value, char* buffer, int maxDecimalPlaces = 324) { + RAPIDJSON_ASSERT(maxDecimalPlaces >= 1); + Double d(value); + if (d.IsZero()) { + if (d.Sign()) + *buffer++ = '-'; // -0.0, Issue #289 + buffer[0] = '0'; + buffer[1] = '.'; + buffer[2] = '0'; + return &buffer[3]; + } + else { + if (value < 0) { + *buffer++ = '-'; + value = -value; + } + int length, K; + Grisu2(value, buffer, &length, &K); + return Prettify(buffer, length, K, maxDecimalPlaces); + } +} + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_DTOA_ diff --git a/include/rapidjson/internal/ieee754.h b/include/rapidjson/internal/ieee754.h new file mode 100644 index 0000000000..68c9e96649 --- /dev/null +++ b/include/rapidjson/internal/ieee754.h @@ -0,0 +1,78 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_IEEE754_ +#define RAPIDJSON_IEEE754_ + +#include "../rapidjson.h" + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +class Double { +public: + Double() {} + Double(double d) : d_(d) {} + Double(uint64_t u) : u_(u) {} + + double Value() const { return d_; } + uint64_t Uint64Value() const { return u_; } + + double NextPositiveDouble() const { + RAPIDJSON_ASSERT(!Sign()); + return Double(u_ + 1).Value(); + } + + bool Sign() const { return (u_ & kSignMask) != 0; } + uint64_t Significand() const { return u_ & kSignificandMask; } + int Exponent() const { return static_cast(((u_ & kExponentMask) >> kSignificandSize) - kExponentBias); } + + bool IsNan() const { return (u_ & kExponentMask) == kExponentMask && Significand() != 0; } + bool IsInf() const { return (u_ & kExponentMask) == kExponentMask && Significand() == 0; } + bool IsNanOrInf() const { return (u_ & kExponentMask) == kExponentMask; } + bool IsNormal() const { return (u_ & kExponentMask) != 0 || Significand() == 0; } + bool IsZero() const { return (u_ & (kExponentMask | kSignificandMask)) == 0; } + + uint64_t IntegerSignificand() const { return IsNormal() ? Significand() | kHiddenBit : Significand(); } + int IntegerExponent() const { return (IsNormal() ? Exponent() : kDenormalExponent) - kSignificandSize; } + uint64_t ToBias() const { return (u_ & kSignMask) ? ~u_ + 1 : u_ | kSignMask; } + + static int EffectiveSignificandSize(int order) { + if (order >= -1021) + return 53; + else if (order <= -1074) + return 0; + else + return order + 1074; + } + +private: + static const int kSignificandSize = 52; + static const int kExponentBias = 0x3FF; + static const int kDenormalExponent = 1 - kExponentBias; + static const uint64_t kSignMask = RAPIDJSON_UINT64_C2(0x80000000, 0x00000000); + static const uint64_t kExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000); + static const uint64_t kSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF); + static const uint64_t kHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000); + + union { + double d_; + uint64_t u_; + }; +}; + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_IEEE754_ diff --git a/include/rapidjson/internal/itoa.h b/include/rapidjson/internal/itoa.h new file mode 100644 index 0000000000..9fe8c932ff --- /dev/null +++ b/include/rapidjson/internal/itoa.h @@ -0,0 +1,308 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ITOA_ +#define RAPIDJSON_ITOA_ + +#include "../rapidjson.h" + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +inline const char* GetDigitsLut() { + static const char cDigitsLut[200] = { + '0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9', + '1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9', + '2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9', + '3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9', + '4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9', + '5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9', + '6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9', + '7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9', + '8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9', + '9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9' + }; + return cDigitsLut; +} + +inline char* u32toa(uint32_t value, char* buffer) { + RAPIDJSON_ASSERT(buffer != 0); + + const char* cDigitsLut = GetDigitsLut(); + + if (value < 10000) { + const uint32_t d1 = (value / 100) << 1; + const uint32_t d2 = (value % 100) << 1; + + if (value >= 1000) + *buffer++ = cDigitsLut[d1]; + if (value >= 100) + *buffer++ = cDigitsLut[d1 + 1]; + if (value >= 10) + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + } + else if (value < 100000000) { + // value = bbbbcccc + const uint32_t b = value / 10000; + const uint32_t c = value % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = cDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = cDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + + *buffer++ = cDigitsLut[d3]; + *buffer++ = cDigitsLut[d3 + 1]; + *buffer++ = cDigitsLut[d4]; + *buffer++ = cDigitsLut[d4 + 1]; + } + else { + // value = aabbbbcccc in decimal + + const uint32_t a = value / 100000000; // 1 to 42 + value %= 100000000; + + if (a >= 10) { + const unsigned i = a << 1; + *buffer++ = cDigitsLut[i]; + *buffer++ = cDigitsLut[i + 1]; + } + else + *buffer++ = static_cast('0' + static_cast(a)); + + const uint32_t b = value / 10000; // 0 to 9999 + const uint32_t c = value % 10000; // 0 to 9999 + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + *buffer++ = cDigitsLut[d1]; + *buffer++ = cDigitsLut[d1 + 1]; + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + *buffer++ = cDigitsLut[d3]; + *buffer++ = cDigitsLut[d3 + 1]; + *buffer++ = cDigitsLut[d4]; + *buffer++ = cDigitsLut[d4 + 1]; + } + return buffer; +} + +inline char* i32toa(int32_t value, char* buffer) { + RAPIDJSON_ASSERT(buffer != 0); + uint32_t u = static_cast(value); + if (value < 0) { + *buffer++ = '-'; + u = ~u + 1; + } + + return u32toa(u, buffer); +} + +inline char* u64toa(uint64_t value, char* buffer) { + RAPIDJSON_ASSERT(buffer != 0); + const char* cDigitsLut = GetDigitsLut(); + const uint64_t kTen8 = 100000000; + const uint64_t kTen9 = kTen8 * 10; + const uint64_t kTen10 = kTen8 * 100; + const uint64_t kTen11 = kTen8 * 1000; + const uint64_t kTen12 = kTen8 * 10000; + const uint64_t kTen13 = kTen8 * 100000; + const uint64_t kTen14 = kTen8 * 1000000; + const uint64_t kTen15 = kTen8 * 10000000; + const uint64_t kTen16 = kTen8 * kTen8; + + if (value < kTen8) { + uint32_t v = static_cast(value); + if (v < 10000) { + const uint32_t d1 = (v / 100) << 1; + const uint32_t d2 = (v % 100) << 1; + + if (v >= 1000) + *buffer++ = cDigitsLut[d1]; + if (v >= 100) + *buffer++ = cDigitsLut[d1 + 1]; + if (v >= 10) + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + } + else { + // value = bbbbcccc + const uint32_t b = v / 10000; + const uint32_t c = v % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = cDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = cDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + + *buffer++ = cDigitsLut[d3]; + *buffer++ = cDigitsLut[d3 + 1]; + *buffer++ = cDigitsLut[d4]; + *buffer++ = cDigitsLut[d4 + 1]; + } + } + else if (value < kTen16) { + const uint32_t v0 = static_cast(value / kTen8); + const uint32_t v1 = static_cast(value % kTen8); + + const uint32_t b0 = v0 / 10000; + const uint32_t c0 = v0 % 10000; + + const uint32_t d1 = (b0 / 100) << 1; + const uint32_t d2 = (b0 % 100) << 1; + + const uint32_t d3 = (c0 / 100) << 1; + const uint32_t d4 = (c0 % 100) << 1; + + const uint32_t b1 = v1 / 10000; + const uint32_t c1 = v1 % 10000; + + const uint32_t d5 = (b1 / 100) << 1; + const uint32_t d6 = (b1 % 100) << 1; + + const uint32_t d7 = (c1 / 100) << 1; + const uint32_t d8 = (c1 % 100) << 1; + + if (value >= kTen15) + *buffer++ = cDigitsLut[d1]; + if (value >= kTen14) + *buffer++ = cDigitsLut[d1 + 1]; + if (value >= kTen13) + *buffer++ = cDigitsLut[d2]; + if (value >= kTen12) + *buffer++ = cDigitsLut[d2 + 1]; + if (value >= kTen11) + *buffer++ = cDigitsLut[d3]; + if (value >= kTen10) + *buffer++ = cDigitsLut[d3 + 1]; + if (value >= kTen9) + *buffer++ = cDigitsLut[d4]; + + *buffer++ = cDigitsLut[d4 + 1]; + *buffer++ = cDigitsLut[d5]; + *buffer++ = cDigitsLut[d5 + 1]; + *buffer++ = cDigitsLut[d6]; + *buffer++ = cDigitsLut[d6 + 1]; + *buffer++ = cDigitsLut[d7]; + *buffer++ = cDigitsLut[d7 + 1]; + *buffer++ = cDigitsLut[d8]; + *buffer++ = cDigitsLut[d8 + 1]; + } + else { + const uint32_t a = static_cast(value / kTen16); // 1 to 1844 + value %= kTen16; + + if (a < 10) + *buffer++ = static_cast('0' + static_cast(a)); + else if (a < 100) { + const uint32_t i = a << 1; + *buffer++ = cDigitsLut[i]; + *buffer++ = cDigitsLut[i + 1]; + } + else if (a < 1000) { + *buffer++ = static_cast('0' + static_cast(a / 100)); + + const uint32_t i = (a % 100) << 1; + *buffer++ = cDigitsLut[i]; + *buffer++ = cDigitsLut[i + 1]; + } + else { + const uint32_t i = (a / 100) << 1; + const uint32_t j = (a % 100) << 1; + *buffer++ = cDigitsLut[i]; + *buffer++ = cDigitsLut[i + 1]; + *buffer++ = cDigitsLut[j]; + *buffer++ = cDigitsLut[j + 1]; + } + + const uint32_t v0 = static_cast(value / kTen8); + const uint32_t v1 = static_cast(value % kTen8); + + const uint32_t b0 = v0 / 10000; + const uint32_t c0 = v0 % 10000; + + const uint32_t d1 = (b0 / 100) << 1; + const uint32_t d2 = (b0 % 100) << 1; + + const uint32_t d3 = (c0 / 100) << 1; + const uint32_t d4 = (c0 % 100) << 1; + + const uint32_t b1 = v1 / 10000; + const uint32_t c1 = v1 % 10000; + + const uint32_t d5 = (b1 / 100) << 1; + const uint32_t d6 = (b1 % 100) << 1; + + const uint32_t d7 = (c1 / 100) << 1; + const uint32_t d8 = (c1 % 100) << 1; + + *buffer++ = cDigitsLut[d1]; + *buffer++ = cDigitsLut[d1 + 1]; + *buffer++ = cDigitsLut[d2]; + *buffer++ = cDigitsLut[d2 + 1]; + *buffer++ = cDigitsLut[d3]; + *buffer++ = cDigitsLut[d3 + 1]; + *buffer++ = cDigitsLut[d4]; + *buffer++ = cDigitsLut[d4 + 1]; + *buffer++ = cDigitsLut[d5]; + *buffer++ = cDigitsLut[d5 + 1]; + *buffer++ = cDigitsLut[d6]; + *buffer++ = cDigitsLut[d6 + 1]; + *buffer++ = cDigitsLut[d7]; + *buffer++ = cDigitsLut[d7 + 1]; + *buffer++ = cDigitsLut[d8]; + *buffer++ = cDigitsLut[d8 + 1]; + } + + return buffer; +} + +inline char* i64toa(int64_t value, char* buffer) { + RAPIDJSON_ASSERT(buffer != 0); + uint64_t u = static_cast(value); + if (value < 0) { + *buffer++ = '-'; + u = ~u + 1; + } + + return u64toa(u, buffer); +} + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_ITOA_ diff --git a/include/rapidjson/internal/meta.h b/include/rapidjson/internal/meta.h new file mode 100644 index 0000000000..27092dc0d6 --- /dev/null +++ b/include/rapidjson/internal/meta.h @@ -0,0 +1,186 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_INTERNAL_META_H_ +#define RAPIDJSON_INTERNAL_META_H_ + +#include "../rapidjson.h" + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#if defined(_MSC_VER) && !defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(6334) +#endif + +#if RAPIDJSON_HAS_CXX11_TYPETRAITS +#include +#endif + +//@cond RAPIDJSON_INTERNAL +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +// Helper to wrap/convert arbitrary types to void, useful for arbitrary type matching +template struct Void { typedef void Type; }; + +/////////////////////////////////////////////////////////////////////////////// +// BoolType, TrueType, FalseType +// +template struct BoolType { + static const bool Value = Cond; + typedef BoolType Type; +}; +typedef BoolType TrueType; +typedef BoolType FalseType; + + +/////////////////////////////////////////////////////////////////////////////// +// SelectIf, BoolExpr, NotExpr, AndExpr, OrExpr +// + +template struct SelectIfImpl { template struct Apply { typedef T1 Type; }; }; +template <> struct SelectIfImpl { template struct Apply { typedef T2 Type; }; }; +template struct SelectIfCond : SelectIfImpl::template Apply {}; +template struct SelectIf : SelectIfCond {}; + +template struct AndExprCond : FalseType {}; +template <> struct AndExprCond : TrueType {}; +template struct OrExprCond : TrueType {}; +template <> struct OrExprCond : FalseType {}; + +template struct BoolExpr : SelectIf::Type {}; +template struct NotExpr : SelectIf::Type {}; +template struct AndExpr : AndExprCond::Type {}; +template struct OrExpr : OrExprCond::Type {}; + + +/////////////////////////////////////////////////////////////////////////////// +// AddConst, MaybeAddConst, RemoveConst +template struct AddConst { typedef const T Type; }; +template struct MaybeAddConst : SelectIfCond {}; +template struct RemoveConst { typedef T Type; }; +template struct RemoveConst { typedef T Type; }; + + +/////////////////////////////////////////////////////////////////////////////// +// IsSame, IsConst, IsMoreConst, IsPointer +// +template struct IsSame : FalseType {}; +template struct IsSame : TrueType {}; + +template struct IsConst : FalseType {}; +template struct IsConst : TrueType {}; + +template +struct IsMoreConst + : AndExpr::Type, typename RemoveConst::Type>, + BoolType::Value >= IsConst::Value> >::Type {}; + +template struct IsPointer : FalseType {}; +template struct IsPointer : TrueType {}; + +/////////////////////////////////////////////////////////////////////////////// +// IsBaseOf +// +#if RAPIDJSON_HAS_CXX11_TYPETRAITS + +template struct IsBaseOf + : BoolType< ::std::is_base_of::value> {}; + +#else // simplified version adopted from Boost + +template struct IsBaseOfImpl { + RAPIDJSON_STATIC_ASSERT(sizeof(B) != 0); + RAPIDJSON_STATIC_ASSERT(sizeof(D) != 0); + + typedef char (&Yes)[1]; + typedef char (&No) [2]; + + template + static Yes Check(const D*, T); + static No Check(const B*, int); + + struct Host { + operator const B*() const; + operator const D*(); + }; + + enum { Value = (sizeof(Check(Host(), 0)) == sizeof(Yes)) }; +}; + +template struct IsBaseOf + : OrExpr, BoolExpr > >::Type {}; + +#endif // RAPIDJSON_HAS_CXX11_TYPETRAITS + + +////////////////////////////////////////////////////////////////////////// +// EnableIf / DisableIf +// +template struct EnableIfCond { typedef T Type; }; +template struct EnableIfCond { /* empty */ }; + +template struct DisableIfCond { typedef T Type; }; +template struct DisableIfCond { /* empty */ }; + +template +struct EnableIf : EnableIfCond {}; + +template +struct DisableIf : DisableIfCond {}; + +// SFINAE helpers +struct SfinaeTag {}; +template struct RemoveSfinaeTag; +template struct RemoveSfinaeTag { typedef T Type; }; + +#define RAPIDJSON_REMOVEFPTR_(type) \ + typename ::RAPIDJSON_NAMESPACE::internal::RemoveSfinaeTag \ + < ::RAPIDJSON_NAMESPACE::internal::SfinaeTag&(*) type>::Type + +#define RAPIDJSON_ENABLEIF(cond) \ + typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \ + ::Type * = NULL + +#define RAPIDJSON_DISABLEIF(cond) \ + typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \ + ::Type * = NULL + +#define RAPIDJSON_ENABLEIF_RETURN(cond,returntype) \ + typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \ + ::Type + +#define RAPIDJSON_DISABLEIF_RETURN(cond,returntype) \ + typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \ + ::Type + +} // namespace internal +RAPIDJSON_NAMESPACE_END +//@endcond + +#if defined(_MSC_VER) && !defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_INTERNAL_META_H_ diff --git a/include/rapidjson/internal/pow10.h b/include/rapidjson/internal/pow10.h new file mode 100644 index 0000000000..eae1a43ed1 --- /dev/null +++ b/include/rapidjson/internal/pow10.h @@ -0,0 +1,55 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_POW10_ +#define RAPIDJSON_POW10_ + +#include "../rapidjson.h" + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +//! Computes integer powers of 10 in double (10.0^n). +/*! This function uses lookup table for fast and accurate results. + \param n non-negative exponent. Must <= 308. + \return 10.0^n +*/ +inline double Pow10(int n) { + static const double e[] = { // 1e-0...1e308: 309 * 8 bytes = 2472 bytes + 1e+0, + 1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10, 1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20, + 1e+21, 1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32, 1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40, + 1e+41, 1e+42, 1e+43, 1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54, 1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60, + 1e+61, 1e+62, 1e+63, 1e+64, 1e+65, 1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76, 1e+77, 1e+78, 1e+79, 1e+80, + 1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87, 1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98, 1e+99, 1e+100, + 1e+101,1e+102,1e+103,1e+104,1e+105,1e+106,1e+107,1e+108,1e+109,1e+110,1e+111,1e+112,1e+113,1e+114,1e+115,1e+116,1e+117,1e+118,1e+119,1e+120, + 1e+121,1e+122,1e+123,1e+124,1e+125,1e+126,1e+127,1e+128,1e+129,1e+130,1e+131,1e+132,1e+133,1e+134,1e+135,1e+136,1e+137,1e+138,1e+139,1e+140, + 1e+141,1e+142,1e+143,1e+144,1e+145,1e+146,1e+147,1e+148,1e+149,1e+150,1e+151,1e+152,1e+153,1e+154,1e+155,1e+156,1e+157,1e+158,1e+159,1e+160, + 1e+161,1e+162,1e+163,1e+164,1e+165,1e+166,1e+167,1e+168,1e+169,1e+170,1e+171,1e+172,1e+173,1e+174,1e+175,1e+176,1e+177,1e+178,1e+179,1e+180, + 1e+181,1e+182,1e+183,1e+184,1e+185,1e+186,1e+187,1e+188,1e+189,1e+190,1e+191,1e+192,1e+193,1e+194,1e+195,1e+196,1e+197,1e+198,1e+199,1e+200, + 1e+201,1e+202,1e+203,1e+204,1e+205,1e+206,1e+207,1e+208,1e+209,1e+210,1e+211,1e+212,1e+213,1e+214,1e+215,1e+216,1e+217,1e+218,1e+219,1e+220, + 1e+221,1e+222,1e+223,1e+224,1e+225,1e+226,1e+227,1e+228,1e+229,1e+230,1e+231,1e+232,1e+233,1e+234,1e+235,1e+236,1e+237,1e+238,1e+239,1e+240, + 1e+241,1e+242,1e+243,1e+244,1e+245,1e+246,1e+247,1e+248,1e+249,1e+250,1e+251,1e+252,1e+253,1e+254,1e+255,1e+256,1e+257,1e+258,1e+259,1e+260, + 1e+261,1e+262,1e+263,1e+264,1e+265,1e+266,1e+267,1e+268,1e+269,1e+270,1e+271,1e+272,1e+273,1e+274,1e+275,1e+276,1e+277,1e+278,1e+279,1e+280, + 1e+281,1e+282,1e+283,1e+284,1e+285,1e+286,1e+287,1e+288,1e+289,1e+290,1e+291,1e+292,1e+293,1e+294,1e+295,1e+296,1e+297,1e+298,1e+299,1e+300, + 1e+301,1e+302,1e+303,1e+304,1e+305,1e+306,1e+307,1e+308 + }; + RAPIDJSON_ASSERT(n >= 0 && n <= 308); + return e[n]; +} + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_POW10_ diff --git a/include/rapidjson/internal/regex.h b/include/rapidjson/internal/regex.h new file mode 100644 index 0000000000..7740dcd527 --- /dev/null +++ b/include/rapidjson/internal/regex.h @@ -0,0 +1,739 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_INTERNAL_REGEX_H_ +#define RAPIDJSON_INTERNAL_REGEX_H_ + +#include "../allocators.h" +#include "../stream.h" +#include "stack.h" + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +RAPIDJSON_DIAG_OFF(switch-enum) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#ifndef RAPIDJSON_REGEX_VERBOSE +#define RAPIDJSON_REGEX_VERBOSE 0 +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +/////////////////////////////////////////////////////////////////////////////// +// DecodedStream + +template +class DecodedStream { +public: + DecodedStream(SourceStream& ss) : ss_(ss), codepoint_() { Decode(); } + unsigned Peek() { return codepoint_; } + unsigned Take() { + unsigned c = codepoint_; + if (c) // No further decoding when '\0' + Decode(); + return c; + } + +private: + void Decode() { + if (!Encoding::Decode(ss_, &codepoint_)) + codepoint_ = 0; + } + + SourceStream& ss_; + unsigned codepoint_; +}; + +/////////////////////////////////////////////////////////////////////////////// +// GenericRegex + +static const SizeType kRegexInvalidState = ~SizeType(0); //!< Represents an invalid index in GenericRegex::State::out, out1 +static const SizeType kRegexInvalidRange = ~SizeType(0); + +template +class GenericRegexSearch; + +//! Regular expression engine with subset of ECMAscript grammar. +/*! + Supported regular expression syntax: + - \c ab Concatenation + - \c a|b Alternation + - \c a? Zero or one + - \c a* Zero or more + - \c a+ One or more + - \c a{3} Exactly 3 times + - \c a{3,} At least 3 times + - \c a{3,5} 3 to 5 times + - \c (ab) Grouping + - \c ^a At the beginning + - \c a$ At the end + - \c . Any character + - \c [abc] Character classes + - \c [a-c] Character class range + - \c [a-z0-9_] Character class combination + - \c [^abc] Negated character classes + - \c [^a-c] Negated character class range + - \c [\b] Backspace (U+0008) + - \c \\| \\\\ ... Escape characters + - \c \\f Form feed (U+000C) + - \c \\n Line feed (U+000A) + - \c \\r Carriage return (U+000D) + - \c \\t Tab (U+0009) + - \c \\v Vertical tab (U+000B) + + \note This is a Thompson NFA engine, implemented with reference to + Cox, Russ. "Regular Expression Matching Can Be Simple And Fast (but is slow in Java, Perl, PHP, Python, Ruby,...).", + https://swtch.com/~rsc/regexp/regexp1.html +*/ +template +class GenericRegex { +public: + typedef Encoding EncodingType; + typedef typename Encoding::Ch Ch; + template friend class GenericRegexSearch; + + GenericRegex(const Ch* source, Allocator* allocator = 0) : + ownAllocator_(allocator ? 0 : RAPIDJSON_NEW(Allocator)()), allocator_(allocator ? allocator : ownAllocator_), + states_(allocator_, 256), ranges_(allocator_, 256), root_(kRegexInvalidState), stateCount_(), rangeCount_(), + anchorBegin_(), anchorEnd_() + { + GenericStringStream ss(source); + DecodedStream, Encoding> ds(ss); + Parse(ds); + } + + ~GenericRegex() + { + RAPIDJSON_DELETE(ownAllocator_); + } + + bool IsValid() const { + return root_ != kRegexInvalidState; + } + +private: + enum Operator { + kZeroOrOne, + kZeroOrMore, + kOneOrMore, + kConcatenation, + kAlternation, + kLeftParenthesis + }; + + static const unsigned kAnyCharacterClass = 0xFFFFFFFF; //!< For '.' + static const unsigned kRangeCharacterClass = 0xFFFFFFFE; + static const unsigned kRangeNegationFlag = 0x80000000; + + struct Range { + unsigned start; // + unsigned end; + SizeType next; + }; + + struct State { + SizeType out; //!< Equals to kInvalid for matching state + SizeType out1; //!< Equals to non-kInvalid for split + SizeType rangeStart; + unsigned codepoint; + }; + + struct Frag { + Frag(SizeType s, SizeType o, SizeType m) : start(s), out(o), minIndex(m) {} + SizeType start; + SizeType out; //!< link-list of all output states + SizeType minIndex; + }; + + State& GetState(SizeType index) { + RAPIDJSON_ASSERT(index < stateCount_); + return states_.template Bottom()[index]; + } + + const State& GetState(SizeType index) const { + RAPIDJSON_ASSERT(index < stateCount_); + return states_.template Bottom()[index]; + } + + Range& GetRange(SizeType index) { + RAPIDJSON_ASSERT(index < rangeCount_); + return ranges_.template Bottom()[index]; + } + + const Range& GetRange(SizeType index) const { + RAPIDJSON_ASSERT(index < rangeCount_); + return ranges_.template Bottom()[index]; + } + + template + void Parse(DecodedStream& ds) { + Stack operandStack(allocator_, 256); // Frag + Stack operatorStack(allocator_, 256); // Operator + Stack atomCountStack(allocator_, 256); // unsigned (Atom per parenthesis) + + *atomCountStack.template Push() = 0; + + unsigned codepoint; + while (ds.Peek() != 0) { + switch (codepoint = ds.Take()) { + case '^': + anchorBegin_ = true; + break; + + case '$': + anchorEnd_ = true; + break; + + case '|': + while (!operatorStack.Empty() && *operatorStack.template Top() < kAlternation) + if (!Eval(operandStack, *operatorStack.template Pop(1))) + return; + *operatorStack.template Push() = kAlternation; + *atomCountStack.template Top() = 0; + break; + + case '(': + *operatorStack.template Push() = kLeftParenthesis; + *atomCountStack.template Push() = 0; + break; + + case ')': + while (!operatorStack.Empty() && *operatorStack.template Top() != kLeftParenthesis) + if (!Eval(operandStack, *operatorStack.template Pop(1))) + return; + if (operatorStack.Empty()) + return; + operatorStack.template Pop(1); + atomCountStack.template Pop(1); + ImplicitConcatenation(atomCountStack, operatorStack); + break; + + case '?': + if (!Eval(operandStack, kZeroOrOne)) + return; + break; + + case '*': + if (!Eval(operandStack, kZeroOrMore)) + return; + break; + + case '+': + if (!Eval(operandStack, kOneOrMore)) + return; + break; + + case '{': + { + unsigned n, m; + if (!ParseUnsigned(ds, &n)) + return; + + if (ds.Peek() == ',') { + ds.Take(); + if (ds.Peek() == '}') + m = kInfinityQuantifier; + else if (!ParseUnsigned(ds, &m) || m < n) + return; + } + else + m = n; + + if (!EvalQuantifier(operandStack, n, m) || ds.Peek() != '}') + return; + ds.Take(); + } + break; + + case '.': + PushOperand(operandStack, kAnyCharacterClass); + ImplicitConcatenation(atomCountStack, operatorStack); + break; + + case '[': + { + SizeType range; + if (!ParseRange(ds, &range)) + return; + SizeType s = NewState(kRegexInvalidState, kRegexInvalidState, kRangeCharacterClass); + GetState(s).rangeStart = range; + *operandStack.template Push() = Frag(s, s, s); + } + ImplicitConcatenation(atomCountStack, operatorStack); + break; + + case '\\': // Escape character + if (!CharacterEscape(ds, &codepoint)) + return; // Unsupported escape character + // fall through to default + RAPIDJSON_DELIBERATE_FALLTHROUGH; + + default: // Pattern character + PushOperand(operandStack, codepoint); + ImplicitConcatenation(atomCountStack, operatorStack); + } + } + + while (!operatorStack.Empty()) + if (!Eval(operandStack, *operatorStack.template Pop(1))) + return; + + // Link the operand to matching state. + if (operandStack.GetSize() == sizeof(Frag)) { + Frag* e = operandStack.template Pop(1); + Patch(e->out, NewState(kRegexInvalidState, kRegexInvalidState, 0)); + root_ = e->start; + +#if RAPIDJSON_REGEX_VERBOSE + printf("root: %d\n", root_); + for (SizeType i = 0; i < stateCount_ ; i++) { + State& s = GetState(i); + printf("[%2d] out: %2d out1: %2d c: '%c'\n", i, s.out, s.out1, (char)s.codepoint); + } + printf("\n"); +#endif + } + } + + SizeType NewState(SizeType out, SizeType out1, unsigned codepoint) { + State* s = states_.template Push(); + s->out = out; + s->out1 = out1; + s->codepoint = codepoint; + s->rangeStart = kRegexInvalidRange; + return stateCount_++; + } + + void PushOperand(Stack& operandStack, unsigned codepoint) { + SizeType s = NewState(kRegexInvalidState, kRegexInvalidState, codepoint); + *operandStack.template Push() = Frag(s, s, s); + } + + void ImplicitConcatenation(Stack& atomCountStack, Stack& operatorStack) { + if (*atomCountStack.template Top()) + *operatorStack.template Push() = kConcatenation; + (*atomCountStack.template Top())++; + } + + SizeType Append(SizeType l1, SizeType l2) { + SizeType old = l1; + while (GetState(l1).out != kRegexInvalidState) + l1 = GetState(l1).out; + GetState(l1).out = l2; + return old; + } + + void Patch(SizeType l, SizeType s) { + for (SizeType next; l != kRegexInvalidState; l = next) { + next = GetState(l).out; + GetState(l).out = s; + } + } + + bool Eval(Stack& operandStack, Operator op) { + switch (op) { + case kConcatenation: + RAPIDJSON_ASSERT(operandStack.GetSize() >= sizeof(Frag) * 2); + { + Frag e2 = *operandStack.template Pop(1); + Frag e1 = *operandStack.template Pop(1); + Patch(e1.out, e2.start); + *operandStack.template Push() = Frag(e1.start, e2.out, Min(e1.minIndex, e2.minIndex)); + } + return true; + + case kAlternation: + if (operandStack.GetSize() >= sizeof(Frag) * 2) { + Frag e2 = *operandStack.template Pop(1); + Frag e1 = *operandStack.template Pop(1); + SizeType s = NewState(e1.start, e2.start, 0); + *operandStack.template Push() = Frag(s, Append(e1.out, e2.out), Min(e1.minIndex, e2.minIndex)); + return true; + } + return false; + + case kZeroOrOne: + if (operandStack.GetSize() >= sizeof(Frag)) { + Frag e = *operandStack.template Pop(1); + SizeType s = NewState(kRegexInvalidState, e.start, 0); + *operandStack.template Push() = Frag(s, Append(e.out, s), e.minIndex); + return true; + } + return false; + + case kZeroOrMore: + if (operandStack.GetSize() >= sizeof(Frag)) { + Frag e = *operandStack.template Pop(1); + SizeType s = NewState(kRegexInvalidState, e.start, 0); + Patch(e.out, s); + *operandStack.template Push() = Frag(s, s, e.minIndex); + return true; + } + return false; + + case kOneOrMore: + if (operandStack.GetSize() >= sizeof(Frag)) { + Frag e = *operandStack.template Pop(1); + SizeType s = NewState(kRegexInvalidState, e.start, 0); + Patch(e.out, s); + *operandStack.template Push() = Frag(e.start, s, e.minIndex); + return true; + } + return false; + + default: + // syntax error (e.g. unclosed kLeftParenthesis) + return false; + } + } + + bool EvalQuantifier(Stack& operandStack, unsigned n, unsigned m) { + RAPIDJSON_ASSERT(n <= m); + RAPIDJSON_ASSERT(operandStack.GetSize() >= sizeof(Frag)); + + if (n == 0) { + if (m == 0) // a{0} not support + return false; + else if (m == kInfinityQuantifier) + Eval(operandStack, kZeroOrMore); // a{0,} -> a* + else { + Eval(operandStack, kZeroOrOne); // a{0,5} -> a? + for (unsigned i = 0; i < m - 1; i++) + CloneTopOperand(operandStack); // a{0,5} -> a? a? a? a? a? + for (unsigned i = 0; i < m - 1; i++) + Eval(operandStack, kConcatenation); // a{0,5} -> a?a?a?a?a? + } + return true; + } + + for (unsigned i = 0; i < n - 1; i++) // a{3} -> a a a + CloneTopOperand(operandStack); + + if (m == kInfinityQuantifier) + Eval(operandStack, kOneOrMore); // a{3,} -> a a a+ + else if (m > n) { + CloneTopOperand(operandStack); // a{3,5} -> a a a a + Eval(operandStack, kZeroOrOne); // a{3,5} -> a a a a? + for (unsigned i = n; i < m - 1; i++) + CloneTopOperand(operandStack); // a{3,5} -> a a a a? a? + for (unsigned i = n; i < m; i++) + Eval(operandStack, kConcatenation); // a{3,5} -> a a aa?a? + } + + for (unsigned i = 0; i < n - 1; i++) + Eval(operandStack, kConcatenation); // a{3} -> aaa, a{3,} -> aaa+, a{3.5} -> aaaa?a? + + return true; + } + + static SizeType Min(SizeType a, SizeType b) { return a < b ? a : b; } + + void CloneTopOperand(Stack& operandStack) { + const Frag src = *operandStack.template Top(); // Copy constructor to prevent invalidation + SizeType count = stateCount_ - src.minIndex; // Assumes top operand contains states in [src->minIndex, stateCount_) + State* s = states_.template Push(count); + memcpy(s, &GetState(src.minIndex), count * sizeof(State)); + for (SizeType j = 0; j < count; j++) { + if (s[j].out != kRegexInvalidState) + s[j].out += count; + if (s[j].out1 != kRegexInvalidState) + s[j].out1 += count; + } + *operandStack.template Push() = Frag(src.start + count, src.out + count, src.minIndex + count); + stateCount_ += count; + } + + template + bool ParseUnsigned(DecodedStream& ds, unsigned* u) { + unsigned r = 0; + if (ds.Peek() < '0' || ds.Peek() > '9') + return false; + while (ds.Peek() >= '0' && ds.Peek() <= '9') { + if (r >= 429496729 && ds.Peek() > '5') // 2^32 - 1 = 4294967295 + return false; // overflow + r = r * 10 + (ds.Take() - '0'); + } + *u = r; + return true; + } + + template + bool ParseRange(DecodedStream& ds, SizeType* range) { + bool isBegin = true; + bool negate = false; + int step = 0; + SizeType start = kRegexInvalidRange; + SizeType current = kRegexInvalidRange; + unsigned codepoint; + while ((codepoint = ds.Take()) != 0) { + if (isBegin) { + isBegin = false; + if (codepoint == '^') { + negate = true; + continue; + } + } + + switch (codepoint) { + case ']': + if (start == kRegexInvalidRange) + return false; // Error: nothing inside [] + if (step == 2) { // Add trailing '-' + SizeType r = NewRange('-'); + RAPIDJSON_ASSERT(current != kRegexInvalidRange); + GetRange(current).next = r; + } + if (negate) + GetRange(start).start |= kRangeNegationFlag; + *range = start; + return true; + + case '\\': + if (ds.Peek() == 'b') { + ds.Take(); + codepoint = 0x0008; // Escape backspace character + } + else if (!CharacterEscape(ds, &codepoint)) + return false; + // fall through to default + RAPIDJSON_DELIBERATE_FALLTHROUGH; + + default: + switch (step) { + case 1: + if (codepoint == '-') { + step++; + break; + } + // fall through to step 0 for other characters + RAPIDJSON_DELIBERATE_FALLTHROUGH; + + case 0: + { + SizeType r = NewRange(codepoint); + if (current != kRegexInvalidRange) + GetRange(current).next = r; + if (start == kRegexInvalidRange) + start = r; + current = r; + } + step = 1; + break; + + default: + RAPIDJSON_ASSERT(step == 2); + GetRange(current).end = codepoint; + step = 0; + } + } + } + return false; + } + + SizeType NewRange(unsigned codepoint) { + Range* r = ranges_.template Push(); + r->start = r->end = codepoint; + r->next = kRegexInvalidRange; + return rangeCount_++; + } + + template + bool CharacterEscape(DecodedStream& ds, unsigned* escapedCodepoint) { + unsigned codepoint; + switch (codepoint = ds.Take()) { + case '^': + case '$': + case '|': + case '(': + case ')': + case '?': + case '*': + case '+': + case '.': + case '[': + case ']': + case '{': + case '}': + case '\\': + *escapedCodepoint = codepoint; return true; + case 'f': *escapedCodepoint = 0x000C; return true; + case 'n': *escapedCodepoint = 0x000A; return true; + case 'r': *escapedCodepoint = 0x000D; return true; + case 't': *escapedCodepoint = 0x0009; return true; + case 'v': *escapedCodepoint = 0x000B; return true; + default: + return false; // Unsupported escape character + } + } + + Allocator* ownAllocator_; + Allocator* allocator_; + Stack states_; + Stack ranges_; + SizeType root_; + SizeType stateCount_; + SizeType rangeCount_; + + static const unsigned kInfinityQuantifier = ~0u; + + // For SearchWithAnchoring() + bool anchorBegin_; + bool anchorEnd_; +}; + +template +class GenericRegexSearch { +public: + typedef typename RegexType::EncodingType Encoding; + typedef typename Encoding::Ch Ch; + + GenericRegexSearch(const RegexType& regex, Allocator* allocator = 0) : + regex_(regex), allocator_(allocator), ownAllocator_(0), + state0_(allocator, 0), state1_(allocator, 0), stateSet_() + { + RAPIDJSON_ASSERT(regex_.IsValid()); + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + stateSet_ = static_cast(allocator_->Malloc(GetStateSetSize())); + state0_.template Reserve(regex_.stateCount_); + state1_.template Reserve(regex_.stateCount_); + } + + ~GenericRegexSearch() { + Allocator::Free(stateSet_); + RAPIDJSON_DELETE(ownAllocator_); + } + + template + bool Match(InputStream& is) { + return SearchWithAnchoring(is, true, true); + } + + bool Match(const Ch* s) { + GenericStringStream is(s); + return Match(is); + } + + template + bool Search(InputStream& is) { + return SearchWithAnchoring(is, regex_.anchorBegin_, regex_.anchorEnd_); + } + + bool Search(const Ch* s) { + GenericStringStream is(s); + return Search(is); + } + +private: + typedef typename RegexType::State State; + typedef typename RegexType::Range Range; + + template + bool SearchWithAnchoring(InputStream& is, bool anchorBegin, bool anchorEnd) { + DecodedStream ds(is); + + state0_.Clear(); + Stack *current = &state0_, *next = &state1_; + const size_t stateSetSize = GetStateSetSize(); + std::memset(stateSet_, 0, stateSetSize); + + bool matched = AddState(*current, regex_.root_); + unsigned codepoint; + while (!current->Empty() && (codepoint = ds.Take()) != 0) { + std::memset(stateSet_, 0, stateSetSize); + next->Clear(); + matched = false; + for (const SizeType* s = current->template Bottom(); s != current->template End(); ++s) { + const State& sr = regex_.GetState(*s); + if (sr.codepoint == codepoint || + sr.codepoint == RegexType::kAnyCharacterClass || + (sr.codepoint == RegexType::kRangeCharacterClass && MatchRange(sr.rangeStart, codepoint))) + { + matched = AddState(*next, sr.out) || matched; + if (!anchorEnd && matched) + return true; + } + if (!anchorBegin) + AddState(*next, regex_.root_); + } + internal::Swap(current, next); + } + + return matched; + } + + size_t GetStateSetSize() const { + return (regex_.stateCount_ + 31) / 32 * 4; + } + + // Return whether the added states is a match state + bool AddState(Stack& l, SizeType index) { + RAPIDJSON_ASSERT(index != kRegexInvalidState); + + const State& s = regex_.GetState(index); + if (s.out1 != kRegexInvalidState) { // Split + bool matched = AddState(l, s.out); + return AddState(l, s.out1) || matched; + } + else if (!(stateSet_[index >> 5] & (1u << (index & 31)))) { + stateSet_[index >> 5] |= (1u << (index & 31)); + *l.template PushUnsafe() = index; + } + return s.out == kRegexInvalidState; // by using PushUnsafe() above, we can ensure s is not validated due to reallocation. + } + + bool MatchRange(SizeType rangeIndex, unsigned codepoint) const { + bool yes = (regex_.GetRange(rangeIndex).start & RegexType::kRangeNegationFlag) == 0; + while (rangeIndex != kRegexInvalidRange) { + const Range& r = regex_.GetRange(rangeIndex); + if (codepoint >= (r.start & ~RegexType::kRangeNegationFlag) && codepoint <= r.end) + return yes; + rangeIndex = r.next; + } + return !yes; + } + + const RegexType& regex_; + Allocator* allocator_; + Allocator* ownAllocator_; + Stack state0_; + Stack state1_; + uint32_t* stateSet_; +}; + +typedef GenericRegex > Regex; +typedef GenericRegexSearch RegexSearch; + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#if defined(__clang__) || defined(_MSC_VER) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_INTERNAL_REGEX_H_ diff --git a/include/rapidjson/internal/stack.h b/include/rapidjson/internal/stack.h new file mode 100644 index 0000000000..73abd706e9 --- /dev/null +++ b/include/rapidjson/internal/stack.h @@ -0,0 +1,232 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_INTERNAL_STACK_H_ +#define RAPIDJSON_INTERNAL_STACK_H_ + +#include "../allocators.h" +#include "swap.h" +#include + +#if defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(c++98-compat) +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +/////////////////////////////////////////////////////////////////////////////// +// Stack + +//! A type-unsafe stack for storing different types of data. +/*! \tparam Allocator Allocator for allocating stack memory. +*/ +template +class Stack { +public: + // Optimization note: Do not allocate memory for stack_ in constructor. + // Do it lazily when first Push() -> Expand() -> Resize(). + Stack(Allocator* allocator, size_t stackCapacity) : allocator_(allocator), ownAllocator_(0), stack_(0), stackTop_(0), stackEnd_(0), initialCapacity_(stackCapacity) { + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + Stack(Stack&& rhs) + : allocator_(rhs.allocator_), + ownAllocator_(rhs.ownAllocator_), + stack_(rhs.stack_), + stackTop_(rhs.stackTop_), + stackEnd_(rhs.stackEnd_), + initialCapacity_(rhs.initialCapacity_) + { + rhs.allocator_ = 0; + rhs.ownAllocator_ = 0; + rhs.stack_ = 0; + rhs.stackTop_ = 0; + rhs.stackEnd_ = 0; + rhs.initialCapacity_ = 0; + } +#endif + + ~Stack() { + Destroy(); + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + Stack& operator=(Stack&& rhs) { + if (&rhs != this) + { + Destroy(); + + allocator_ = rhs.allocator_; + ownAllocator_ = rhs.ownAllocator_; + stack_ = rhs.stack_; + stackTop_ = rhs.stackTop_; + stackEnd_ = rhs.stackEnd_; + initialCapacity_ = rhs.initialCapacity_; + + rhs.allocator_ = 0; + rhs.ownAllocator_ = 0; + rhs.stack_ = 0; + rhs.stackTop_ = 0; + rhs.stackEnd_ = 0; + rhs.initialCapacity_ = 0; + } + return *this; + } +#endif + + void Swap(Stack& rhs) RAPIDJSON_NOEXCEPT { + internal::Swap(allocator_, rhs.allocator_); + internal::Swap(ownAllocator_, rhs.ownAllocator_); + internal::Swap(stack_, rhs.stack_); + internal::Swap(stackTop_, rhs.stackTop_); + internal::Swap(stackEnd_, rhs.stackEnd_); + internal::Swap(initialCapacity_, rhs.initialCapacity_); + } + + void Clear() { stackTop_ = stack_; } + + void ShrinkToFit() { + if (Empty()) { + // If the stack is empty, completely deallocate the memory. + Allocator::Free(stack_); // NOLINT (+clang-analyzer-unix.Malloc) + stack_ = 0; + stackTop_ = 0; + stackEnd_ = 0; + } + else + Resize(GetSize()); + } + + // Optimization note: try to minimize the size of this function for force inline. + // Expansion is run very infrequently, so it is moved to another (probably non-inline) function. + template + RAPIDJSON_FORCEINLINE void Reserve(size_t count = 1) { + // Expand the stack if needed + if (RAPIDJSON_UNLIKELY(static_cast(sizeof(T) * count) > (stackEnd_ - stackTop_))) + Expand(count); + } + + template + RAPIDJSON_FORCEINLINE T* Push(size_t count = 1) { + Reserve(count); + return PushUnsafe(count); + } + + template + RAPIDJSON_FORCEINLINE T* PushUnsafe(size_t count = 1) { + RAPIDJSON_ASSERT(stackTop_); + RAPIDJSON_ASSERT(static_cast(sizeof(T) * count) <= (stackEnd_ - stackTop_)); + T* ret = reinterpret_cast(stackTop_); + stackTop_ += sizeof(T) * count; + return ret; + } + + template + T* Pop(size_t count) { + RAPIDJSON_ASSERT(GetSize() >= count * sizeof(T)); + stackTop_ -= count * sizeof(T); + return reinterpret_cast(stackTop_); + } + + template + T* Top() { + RAPIDJSON_ASSERT(GetSize() >= sizeof(T)); + return reinterpret_cast(stackTop_ - sizeof(T)); + } + + template + const T* Top() const { + RAPIDJSON_ASSERT(GetSize() >= sizeof(T)); + return reinterpret_cast(stackTop_ - sizeof(T)); + } + + template + T* End() { return reinterpret_cast(stackTop_); } + + template + const T* End() const { return reinterpret_cast(stackTop_); } + + template + T* Bottom() { return reinterpret_cast(stack_); } + + template + const T* Bottom() const { return reinterpret_cast(stack_); } + + bool HasAllocator() const { + return allocator_ != 0; + } + + Allocator& GetAllocator() { + RAPIDJSON_ASSERT(allocator_); + return *allocator_; + } + + bool Empty() const { return stackTop_ == stack_; } + size_t GetSize() const { return static_cast(stackTop_ - stack_); } + size_t GetCapacity() const { return static_cast(stackEnd_ - stack_); } + +private: + template + void Expand(size_t count) { + // Only expand the capacity if the current stack exists. Otherwise just create a stack with initial capacity. + size_t newCapacity; + if (stack_ == 0) { + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + newCapacity = initialCapacity_; + } else { + newCapacity = GetCapacity(); + newCapacity += (newCapacity + 1) / 2; + } + size_t newSize = GetSize() + sizeof(T) * count; + if (newCapacity < newSize) + newCapacity = newSize; + + Resize(newCapacity); + } + + void Resize(size_t newCapacity) { + const size_t size = GetSize(); // Backup the current size + stack_ = static_cast(allocator_->Realloc(stack_, GetCapacity(), newCapacity)); + stackTop_ = stack_ + size; + stackEnd_ = stack_ + newCapacity; + } + + void Destroy() { + Allocator::Free(stack_); + RAPIDJSON_DELETE(ownAllocator_); // Only delete if it is owned by the stack + } + + // Prohibit copy constructor & assignment operator. + Stack(const Stack&); + Stack& operator=(const Stack&); + + Allocator* allocator_; + Allocator* ownAllocator_; + char *stack_; + char *stackTop_; + char *stackEnd_; + size_t initialCapacity_; +}; + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_STACK_H_ diff --git a/include/rapidjson/internal/strfunc.h b/include/rapidjson/internal/strfunc.h new file mode 100644 index 0000000000..b698a8f43f --- /dev/null +++ b/include/rapidjson/internal/strfunc.h @@ -0,0 +1,83 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_INTERNAL_STRFUNC_H_ +#define RAPIDJSON_INTERNAL_STRFUNC_H_ + +#include "../stream.h" +#include + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +//! Custom strlen() which works on different character types. +/*! \tparam Ch Character type (e.g. char, wchar_t, short) + \param s Null-terminated input string. + \return Number of characters in the string. + \note This has the same semantics as strlen(), the return value is not number of Unicode codepoints. +*/ +template +inline SizeType StrLen(const Ch* s) { + RAPIDJSON_ASSERT(s != 0); + const Ch* p = s; + while (*p) ++p; + return SizeType(p - s); +} + +template <> +inline SizeType StrLen(const char* s) { + return SizeType(std::strlen(s)); +} + +template <> +inline SizeType StrLen(const wchar_t* s) { + return SizeType(std::wcslen(s)); +} + +//! Custom strcmpn() which works on different character types. +/*! \tparam Ch Character type (e.g. char, wchar_t, short) + \param s1 Null-terminated input string. + \param s2 Null-terminated input string. + \return 0 if equal +*/ +template +inline int StrCmp(const Ch* s1, const Ch* s2) { + RAPIDJSON_ASSERT(s1 != 0); + RAPIDJSON_ASSERT(s2 != 0); + while(*s1 && (*s1 == *s2)) { s1++; s2++; } + return static_cast(*s1) < static_cast(*s2) ? -1 : static_cast(*s1) > static_cast(*s2); +} + +//! Returns number of code points in a encoded string. +template +bool CountStringCodePoint(const typename Encoding::Ch* s, SizeType length, SizeType* outCount) { + RAPIDJSON_ASSERT(s != 0); + RAPIDJSON_ASSERT(outCount != 0); + GenericStringStream is(s); + const typename Encoding::Ch* end = s + length; + SizeType count = 0; + while (is.src_ < end) { + unsigned codepoint; + if (!Encoding::Decode(is, &codepoint)) + return false; + count++; + } + *outCount = count; + return true; +} + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_INTERNAL_STRFUNC_H_ diff --git a/include/rapidjson/internal/strtod.h b/include/rapidjson/internal/strtod.h new file mode 100644 index 0000000000..57c8418bd9 --- /dev/null +++ b/include/rapidjson/internal/strtod.h @@ -0,0 +1,293 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_STRTOD_ +#define RAPIDJSON_STRTOD_ + +#include "ieee754.h" +#include "biginteger.h" +#include "diyfp.h" +#include "pow10.h" +#include +#include + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +inline double FastPath(double significand, int exp) { + if (exp < -308) + return 0.0; + else if (exp >= 0) + return significand * internal::Pow10(exp); + else + return significand / internal::Pow10(-exp); +} + +inline double StrtodNormalPrecision(double d, int p) { + if (p < -308) { + // Prevent expSum < -308, making Pow10(p) = 0 + d = FastPath(d, -308); + d = FastPath(d, p + 308); + } + else + d = FastPath(d, p); + return d; +} + +template +inline T Min3(T a, T b, T c) { + T m = a; + if (m > b) m = b; + if (m > c) m = c; + return m; +} + +inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) { + const Double db(b); + const uint64_t bInt = db.IntegerSignificand(); + const int bExp = db.IntegerExponent(); + const int hExp = bExp - 1; + + int dS_Exp2 = 0, dS_Exp5 = 0, bS_Exp2 = 0, bS_Exp5 = 0, hS_Exp2 = 0, hS_Exp5 = 0; + + // Adjust for decimal exponent + if (dExp >= 0) { + dS_Exp2 += dExp; + dS_Exp5 += dExp; + } + else { + bS_Exp2 -= dExp; + bS_Exp5 -= dExp; + hS_Exp2 -= dExp; + hS_Exp5 -= dExp; + } + + // Adjust for binary exponent + if (bExp >= 0) + bS_Exp2 += bExp; + else { + dS_Exp2 -= bExp; + hS_Exp2 -= bExp; + } + + // Adjust for half ulp exponent + if (hExp >= 0) + hS_Exp2 += hExp; + else { + dS_Exp2 -= hExp; + bS_Exp2 -= hExp; + } + + // Remove common power of two factor from all three scaled values + int common_Exp2 = Min3(dS_Exp2, bS_Exp2, hS_Exp2); + dS_Exp2 -= common_Exp2; + bS_Exp2 -= common_Exp2; + hS_Exp2 -= common_Exp2; + + BigInteger dS = d; + dS.MultiplyPow5(static_cast(dS_Exp5)) <<= static_cast(dS_Exp2); + + BigInteger bS(bInt); + bS.MultiplyPow5(static_cast(bS_Exp5)) <<= static_cast(bS_Exp2); + + BigInteger hS(1); + hS.MultiplyPow5(static_cast(hS_Exp5)) <<= static_cast(hS_Exp2); + + BigInteger delta(0); + dS.Difference(bS, &delta); + + return delta.Compare(hS); +} + +inline bool StrtodFast(double d, int p, double* result) { + // Use fast path for string-to-double conversion if possible + // see http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/ + if (p > 22 && p < 22 + 16) { + // Fast Path Cases In Disguise + d *= internal::Pow10(p - 22); + p = 22; + } + + if (p >= -22 && p <= 22 && d <= 9007199254740991.0) { // 2^53 - 1 + *result = FastPath(d, p); + return true; + } + else + return false; +} + +// Compute an approximation and see if it is within 1/2 ULP +template +inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result) { + uint64_t significand = 0; + int i = 0; // 2^64 - 1 = 18446744073709551615, 1844674407370955161 = 0x1999999999999999 + for (; i < dLen; i++) { + if (significand > RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) || + (significand == RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) && decimals[i] >= Ch('5'))) + break; + significand = significand * 10u + static_cast(decimals[i] - Ch('0')); + } + + if (i < dLen && decimals[i] >= Ch('5')) // Rounding + significand++; + + int remaining = dLen - i; + const int kUlpShift = 3; + const int kUlp = 1 << kUlpShift; + int64_t error = (remaining == 0) ? 0 : kUlp / 2; + + DiyFp v(significand, 0); + v = v.Normalize(); + error <<= -v.e; + + dExp += remaining; + + int actualExp; + DiyFp cachedPower = GetCachedPower10(dExp, &actualExp); + if (actualExp != dExp) { + static const DiyFp kPow10[] = { + DiyFp(RAPIDJSON_UINT64_C2(0xa0000000, 0x00000000), -60), // 10^1 + DiyFp(RAPIDJSON_UINT64_C2(0xc8000000, 0x00000000), -57), // 10^2 + DiyFp(RAPIDJSON_UINT64_C2(0xfa000000, 0x00000000), -54), // 10^3 + DiyFp(RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), -50), // 10^4 + DiyFp(RAPIDJSON_UINT64_C2(0xc3500000, 0x00000000), -47), // 10^5 + DiyFp(RAPIDJSON_UINT64_C2(0xf4240000, 0x00000000), -44), // 10^6 + DiyFp(RAPIDJSON_UINT64_C2(0x98968000, 0x00000000), -40) // 10^7 + }; + int adjustment = dExp - actualExp; + RAPIDJSON_ASSERT(adjustment >= 1 && adjustment < 8); + v = v * kPow10[adjustment - 1]; + if (dLen + adjustment > 19) // has more digits than decimal digits in 64-bit + error += kUlp / 2; + } + + v = v * cachedPower; + + error += kUlp + (error == 0 ? 0 : 1); + + const int oldExp = v.e; + v = v.Normalize(); + error <<= oldExp - v.e; + + const int effectiveSignificandSize = Double::EffectiveSignificandSize(64 + v.e); + int precisionSize = 64 - effectiveSignificandSize; + if (precisionSize + kUlpShift >= 64) { + int scaleExp = (precisionSize + kUlpShift) - 63; + v.f >>= scaleExp; + v.e += scaleExp; + error = (error >> scaleExp) + 1 + kUlp; + precisionSize -= scaleExp; + } + + DiyFp rounded(v.f >> precisionSize, v.e + precisionSize); + const uint64_t precisionBits = (v.f & ((uint64_t(1) << precisionSize) - 1)) * kUlp; + const uint64_t halfWay = (uint64_t(1) << (precisionSize - 1)) * kUlp; + if (precisionBits >= halfWay + static_cast(error)) { + rounded.f++; + if (rounded.f & (DiyFp::kDpHiddenBit << 1)) { // rounding overflows mantissa (issue #340) + rounded.f >>= 1; + rounded.e++; + } + } + + *result = rounded.ToDouble(); + + return halfWay - static_cast(error) >= precisionBits || precisionBits >= halfWay + static_cast(error); +} + +template +inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int dExp) { + RAPIDJSON_ASSERT(dLen >= 0); + const BigInteger dInt(decimals, static_cast(dLen)); + Double a(approx); + int cmp = CheckWithinHalfULP(a.Value(), dInt, dExp); + if (cmp < 0) + return a.Value(); // within half ULP + else if (cmp == 0) { + // Round towards even + if (a.Significand() & 1) + return a.NextPositiveDouble(); + else + return a.Value(); + } + else // adjustment + return a.NextPositiveDouble(); +} + +template +inline double StrtodFullPrecision(double d, int p, const Ch* decimals, size_t length, size_t decimalPosition, int exp) { + RAPIDJSON_ASSERT(d >= 0.0); + RAPIDJSON_ASSERT(length >= 1); + + double result = 0.0; + if (StrtodFast(d, p, &result)) + return result; + + RAPIDJSON_ASSERT(length <= INT_MAX); + int dLen = static_cast(length); + + RAPIDJSON_ASSERT(length >= decimalPosition); + RAPIDJSON_ASSERT(length - decimalPosition <= INT_MAX); + int dExpAdjust = static_cast(length - decimalPosition); + + RAPIDJSON_ASSERT(exp >= INT_MIN + dExpAdjust); + int dExp = exp - dExpAdjust; + + // Make sure length+dExp does not overflow + RAPIDJSON_ASSERT(dExp <= INT_MAX - dLen); + + // Trim leading zeros + while (dLen > 0 && *decimals == '0') { + dLen--; + decimals++; + } + + // Trim trailing zeros + while (dLen > 0 && decimals[dLen - 1] == '0') { + dLen--; + dExp++; + } + + if (dLen == 0) { // Buffer only contains zeros. + return 0.0; + } + + // Trim right-most digits + const int kMaxDecimalDigit = 767 + 1; + if (dLen > kMaxDecimalDigit) { + dExp += dLen - kMaxDecimalDigit; + dLen = kMaxDecimalDigit; + } + + // If too small, underflow to zero. + // Any x <= 10^-324 is interpreted as zero. + if (dLen + dExp <= -324) + return 0.0; + + // If too large, overflow to infinity. + // Any x >= 10^309 is interpreted as +infinity. + if (dLen + dExp > 309) + return std::numeric_limits::infinity(); + + if (StrtodDiyFp(decimals, dLen, dExp, &result)) + return result; + + // Use approximation from StrtodDiyFp and make adjustment with BigInteger comparison + return StrtodBigInteger(result, decimals, dLen, dExp); +} + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_STRTOD_ diff --git a/include/rapidjson/internal/swap.h b/include/rapidjson/internal/swap.h new file mode 100644 index 0000000000..2cf92f93a1 --- /dev/null +++ b/include/rapidjson/internal/swap.h @@ -0,0 +1,46 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_INTERNAL_SWAP_H_ +#define RAPIDJSON_INTERNAL_SWAP_H_ + +#include "../rapidjson.h" + +#if defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(c++98-compat) +#endif + +RAPIDJSON_NAMESPACE_BEGIN +namespace internal { + +//! Custom swap() to avoid dependency on C++ header +/*! \tparam T Type of the arguments to swap, should be instantiated with primitive C++ types only. + \note This has the same semantics as std::swap(). +*/ +template +inline void Swap(T& a, T& b) RAPIDJSON_NOEXCEPT { + T tmp = a; + a = b; + b = tmp; +} + +} // namespace internal +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_INTERNAL_SWAP_H_ diff --git a/include/rapidjson/istreamwrapper.h b/include/rapidjson/istreamwrapper.h new file mode 100644 index 0000000000..01437ec012 --- /dev/null +++ b/include/rapidjson/istreamwrapper.h @@ -0,0 +1,128 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_ISTREAMWRAPPER_H_ +#define RAPIDJSON_ISTREAMWRAPPER_H_ + +#include "stream.h" +#include +#include + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4351) // new behavior: elements of array 'array' will be default initialized +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Wrapper of \c std::basic_istream into RapidJSON's Stream concept. +/*! + The classes can be wrapped including but not limited to: + + - \c std::istringstream + - \c std::stringstream + - \c std::wistringstream + - \c std::wstringstream + - \c std::ifstream + - \c std::fstream + - \c std::wifstream + - \c std::wfstream + + \tparam StreamType Class derived from \c std::basic_istream. +*/ + +template +class BasicIStreamWrapper { +public: + typedef typename StreamType::char_type Ch; + + //! Constructor. + /*! + \param stream stream opened for read. + */ + BasicIStreamWrapper(StreamType &stream) : stream_(stream), buffer_(peekBuffer_), bufferSize_(4), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) { + Read(); + } + + //! Constructor. + /*! + \param stream stream opened for read. + \param buffer user-supplied buffer. + \param bufferSize size of buffer in bytes. Must >=4 bytes. + */ + BasicIStreamWrapper(StreamType &stream, char* buffer, size_t bufferSize) : stream_(stream), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) { + RAPIDJSON_ASSERT(bufferSize >= 4); + Read(); + } + + Ch Peek() const { return *current_; } + Ch Take() { Ch c = *current_; Read(); return c; } + size_t Tell() const { return count_ + static_cast(current_ - buffer_); } + + // Not implemented + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + + // For encoding detection only. + const Ch* Peek4() const { + return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; + } + +private: + BasicIStreamWrapper(); + BasicIStreamWrapper(const BasicIStreamWrapper&); + BasicIStreamWrapper& operator=(const BasicIStreamWrapper&); + + void Read() { + if (current_ < bufferLast_) + ++current_; + else if (!eof_) { + count_ += readCount_; + readCount_ = bufferSize_; + bufferLast_ = buffer_ + readCount_ - 1; + current_ = buffer_; + + if (!stream_.read(buffer_, static_cast(bufferSize_))) { + readCount_ = static_cast(stream_.gcount()); + *(bufferLast_ = buffer_ + readCount_) = '\0'; + eof_ = true; + } + } + } + + StreamType &stream_; + Ch peekBuffer_[4], *buffer_; + size_t bufferSize_; + Ch *bufferLast_; + Ch *current_; + size_t readCount_; + size_t count_; //!< Number of characters read + bool eof_; +}; + +typedef BasicIStreamWrapper IStreamWrapper; +typedef BasicIStreamWrapper WIStreamWrapper; + +#if defined(__clang__) || defined(_MSC_VER) +RAPIDJSON_DIAG_POP +#endif + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_ISTREAMWRAPPER_H_ diff --git a/include/rapidjson/memorybuffer.h b/include/rapidjson/memorybuffer.h new file mode 100644 index 0000000000..ffbc41ed1f --- /dev/null +++ b/include/rapidjson/memorybuffer.h @@ -0,0 +1,70 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_MEMORYBUFFER_H_ +#define RAPIDJSON_MEMORYBUFFER_H_ + +#include "stream.h" +#include "internal/stack.h" + +RAPIDJSON_NAMESPACE_BEGIN + +//! Represents an in-memory output byte stream. +/*! + This class is mainly for being wrapped by EncodedOutputStream or AutoUTFOutputStream. + + It is similar to FileWriteBuffer but the destination is an in-memory buffer instead of a file. + + Differences between MemoryBuffer and StringBuffer: + 1. StringBuffer has Encoding but MemoryBuffer is only a byte buffer. + 2. StringBuffer::GetString() returns a null-terminated string. MemoryBuffer::GetBuffer() returns a buffer without terminator. + + \tparam Allocator type for allocating memory buffer. + \note implements Stream concept +*/ +template +struct GenericMemoryBuffer { + typedef char Ch; // byte + + GenericMemoryBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {} + + void Put(Ch c) { *stack_.template Push() = c; } + void Flush() {} + + void Clear() { stack_.Clear(); } + void ShrinkToFit() { stack_.ShrinkToFit(); } + Ch* Push(size_t count) { return stack_.template Push(count); } + void Pop(size_t count) { stack_.template Pop(count); } + + const Ch* GetBuffer() const { + return stack_.template Bottom(); + } + + size_t GetSize() const { return stack_.GetSize(); } + + static const size_t kDefaultCapacity = 256; + mutable internal::Stack stack_; +}; + +typedef GenericMemoryBuffer<> MemoryBuffer; + +//! Implement specialized version of PutN() with memset() for better performance. +template<> +inline void PutN(MemoryBuffer& memoryBuffer, char c, size_t n) { + std::memset(memoryBuffer.stack_.Push(n), c, n * sizeof(c)); +} + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_MEMORYBUFFER_H_ diff --git a/include/rapidjson/memorystream.h b/include/rapidjson/memorystream.h new file mode 100644 index 0000000000..77af6c999e --- /dev/null +++ b/include/rapidjson/memorystream.h @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_MEMORYSTREAM_H_ +#define RAPIDJSON_MEMORYSTREAM_H_ + +#include "stream.h" + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(unreachable-code) +RAPIDJSON_DIAG_OFF(missing-noreturn) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Represents an in-memory input byte stream. +/*! + This class is mainly for being wrapped by EncodedInputStream or AutoUTFInputStream. + + It is similar to FileReadBuffer but the source is an in-memory buffer instead of a file. + + Differences between MemoryStream and StringStream: + 1. StringStream has encoding but MemoryStream is a byte stream. + 2. MemoryStream needs size of the source buffer and the buffer don't need to be null terminated. StringStream assume null-terminated string as source. + 3. MemoryStream supports Peek4() for encoding detection. StringStream is specified with an encoding so it should not have Peek4(). + \note implements Stream concept +*/ +struct MemoryStream { + typedef char Ch; // byte + + MemoryStream(const Ch *src, size_t size) : src_(src), begin_(src), end_(src + size), size_(size) {} + + Ch Peek() const { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_; } + Ch Take() { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_++; } + size_t Tell() const { return static_cast(src_ - begin_); } + + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + + // For encoding detection only. + const Ch* Peek4() const { + return Tell() + 4 <= size_ ? src_ : 0; + } + + const Ch* src_; //!< Current read position. + const Ch* begin_; //!< Original head of the string. + const Ch* end_; //!< End of stream. + size_t size_; //!< Size of the stream. +}; + +RAPIDJSON_NAMESPACE_END + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_MEMORYBUFFER_H_ diff --git a/include/rapidjson/msinttypes/inttypes.h b/include/rapidjson/msinttypes/inttypes.h new file mode 100644 index 0000000000..18111286bf --- /dev/null +++ b/include/rapidjson/msinttypes/inttypes.h @@ -0,0 +1,316 @@ +// ISO C9x compliant inttypes.h for Microsoft Visual Studio +// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 +// +// Copyright (c) 2006-2013 Alexander Chemeris +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the product nor the names of its contributors may +// be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO +// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +/////////////////////////////////////////////////////////////////////////////// + +// The above software in this distribution may have been modified by +// THL A29 Limited ("Tencent Modifications"). +// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited. + +#ifndef _MSC_VER // [ +#error "Use this header only with Microsoft Visual C++ compilers!" +#endif // _MSC_VER ] + +#ifndef _MSC_INTTYPES_H_ // [ +#define _MSC_INTTYPES_H_ + +#if _MSC_VER > 1000 +#pragma once +#endif + +#include "stdint.h" + +// miloyip: VC supports inttypes.h since VC2013 +#if _MSC_VER >= 1800 +#include +#else + +// 7.8 Format conversion of integer types + +typedef struct { + intmax_t quot; + intmax_t rem; +} imaxdiv_t; + +// 7.8.1 Macros for format specifiers + +#if !defined(__cplusplus) || defined(__STDC_FORMAT_MACROS) // [ See footnote 185 at page 198 + +// The fprintf macros for signed integers are: +#define PRId8 "d" +#define PRIi8 "i" +#define PRIdLEAST8 "d" +#define PRIiLEAST8 "i" +#define PRIdFAST8 "d" +#define PRIiFAST8 "i" + +#define PRId16 "hd" +#define PRIi16 "hi" +#define PRIdLEAST16 "hd" +#define PRIiLEAST16 "hi" +#define PRIdFAST16 "hd" +#define PRIiFAST16 "hi" + +#define PRId32 "I32d" +#define PRIi32 "I32i" +#define PRIdLEAST32 "I32d" +#define PRIiLEAST32 "I32i" +#define PRIdFAST32 "I32d" +#define PRIiFAST32 "I32i" + +#define PRId64 "I64d" +#define PRIi64 "I64i" +#define PRIdLEAST64 "I64d" +#define PRIiLEAST64 "I64i" +#define PRIdFAST64 "I64d" +#define PRIiFAST64 "I64i" + +#define PRIdMAX "I64d" +#define PRIiMAX "I64i" + +#define PRIdPTR "Id" +#define PRIiPTR "Ii" + +// The fprintf macros for unsigned integers are: +#define PRIo8 "o" +#define PRIu8 "u" +#define PRIx8 "x" +#define PRIX8 "X" +#define PRIoLEAST8 "o" +#define PRIuLEAST8 "u" +#define PRIxLEAST8 "x" +#define PRIXLEAST8 "X" +#define PRIoFAST8 "o" +#define PRIuFAST8 "u" +#define PRIxFAST8 "x" +#define PRIXFAST8 "X" + +#define PRIo16 "ho" +#define PRIu16 "hu" +#define PRIx16 "hx" +#define PRIX16 "hX" +#define PRIoLEAST16 "ho" +#define PRIuLEAST16 "hu" +#define PRIxLEAST16 "hx" +#define PRIXLEAST16 "hX" +#define PRIoFAST16 "ho" +#define PRIuFAST16 "hu" +#define PRIxFAST16 "hx" +#define PRIXFAST16 "hX" + +#define PRIo32 "I32o" +#define PRIu32 "I32u" +#define PRIx32 "I32x" +#define PRIX32 "I32X" +#define PRIoLEAST32 "I32o" +#define PRIuLEAST32 "I32u" +#define PRIxLEAST32 "I32x" +#define PRIXLEAST32 "I32X" +#define PRIoFAST32 "I32o" +#define PRIuFAST32 "I32u" +#define PRIxFAST32 "I32x" +#define PRIXFAST32 "I32X" + +#define PRIo64 "I64o" +#define PRIu64 "I64u" +#define PRIx64 "I64x" +#define PRIX64 "I64X" +#define PRIoLEAST64 "I64o" +#define PRIuLEAST64 "I64u" +#define PRIxLEAST64 "I64x" +#define PRIXLEAST64 "I64X" +#define PRIoFAST64 "I64o" +#define PRIuFAST64 "I64u" +#define PRIxFAST64 "I64x" +#define PRIXFAST64 "I64X" + +#define PRIoMAX "I64o" +#define PRIuMAX "I64u" +#define PRIxMAX "I64x" +#define PRIXMAX "I64X" + +#define PRIoPTR "Io" +#define PRIuPTR "Iu" +#define PRIxPTR "Ix" +#define PRIXPTR "IX" + +// The fscanf macros for signed integers are: +#define SCNd8 "d" +#define SCNi8 "i" +#define SCNdLEAST8 "d" +#define SCNiLEAST8 "i" +#define SCNdFAST8 "d" +#define SCNiFAST8 "i" + +#define SCNd16 "hd" +#define SCNi16 "hi" +#define SCNdLEAST16 "hd" +#define SCNiLEAST16 "hi" +#define SCNdFAST16 "hd" +#define SCNiFAST16 "hi" + +#define SCNd32 "ld" +#define SCNi32 "li" +#define SCNdLEAST32 "ld" +#define SCNiLEAST32 "li" +#define SCNdFAST32 "ld" +#define SCNiFAST32 "li" + +#define SCNd64 "I64d" +#define SCNi64 "I64i" +#define SCNdLEAST64 "I64d" +#define SCNiLEAST64 "I64i" +#define SCNdFAST64 "I64d" +#define SCNiFAST64 "I64i" + +#define SCNdMAX "I64d" +#define SCNiMAX "I64i" + +#ifdef _WIN64 // [ +# define SCNdPTR "I64d" +# define SCNiPTR "I64i" +#else // _WIN64 ][ +# define SCNdPTR "ld" +# define SCNiPTR "li" +#endif // _WIN64 ] + +// The fscanf macros for unsigned integers are: +#define SCNo8 "o" +#define SCNu8 "u" +#define SCNx8 "x" +#define SCNX8 "X" +#define SCNoLEAST8 "o" +#define SCNuLEAST8 "u" +#define SCNxLEAST8 "x" +#define SCNXLEAST8 "X" +#define SCNoFAST8 "o" +#define SCNuFAST8 "u" +#define SCNxFAST8 "x" +#define SCNXFAST8 "X" + +#define SCNo16 "ho" +#define SCNu16 "hu" +#define SCNx16 "hx" +#define SCNX16 "hX" +#define SCNoLEAST16 "ho" +#define SCNuLEAST16 "hu" +#define SCNxLEAST16 "hx" +#define SCNXLEAST16 "hX" +#define SCNoFAST16 "ho" +#define SCNuFAST16 "hu" +#define SCNxFAST16 "hx" +#define SCNXFAST16 "hX" + +#define SCNo32 "lo" +#define SCNu32 "lu" +#define SCNx32 "lx" +#define SCNX32 "lX" +#define SCNoLEAST32 "lo" +#define SCNuLEAST32 "lu" +#define SCNxLEAST32 "lx" +#define SCNXLEAST32 "lX" +#define SCNoFAST32 "lo" +#define SCNuFAST32 "lu" +#define SCNxFAST32 "lx" +#define SCNXFAST32 "lX" + +#define SCNo64 "I64o" +#define SCNu64 "I64u" +#define SCNx64 "I64x" +#define SCNX64 "I64X" +#define SCNoLEAST64 "I64o" +#define SCNuLEAST64 "I64u" +#define SCNxLEAST64 "I64x" +#define SCNXLEAST64 "I64X" +#define SCNoFAST64 "I64o" +#define SCNuFAST64 "I64u" +#define SCNxFAST64 "I64x" +#define SCNXFAST64 "I64X" + +#define SCNoMAX "I64o" +#define SCNuMAX "I64u" +#define SCNxMAX "I64x" +#define SCNXMAX "I64X" + +#ifdef _WIN64 // [ +# define SCNoPTR "I64o" +# define SCNuPTR "I64u" +# define SCNxPTR "I64x" +# define SCNXPTR "I64X" +#else // _WIN64 ][ +# define SCNoPTR "lo" +# define SCNuPTR "lu" +# define SCNxPTR "lx" +# define SCNXPTR "lX" +#endif // _WIN64 ] + +#endif // __STDC_FORMAT_MACROS ] + +// 7.8.2 Functions for greatest-width integer types + +// 7.8.2.1 The imaxabs function +#define imaxabs _abs64 + +// 7.8.2.2 The imaxdiv function + +// This is modified version of div() function from Microsoft's div.c found +// in %MSVC.NET%\crt\src\div.c +#ifdef STATIC_IMAXDIV // [ +static +#else // STATIC_IMAXDIV ][ +_inline +#endif // STATIC_IMAXDIV ] +imaxdiv_t __cdecl imaxdiv(intmax_t numer, intmax_t denom) +{ + imaxdiv_t result; + + result.quot = numer / denom; + result.rem = numer % denom; + + if (numer < 0 && result.rem > 0) { + // did division wrong; must fix up + ++result.quot; + result.rem -= denom; + } + + return result; +} + +// 7.8.2.3 The strtoimax and strtoumax functions +#define strtoimax _strtoi64 +#define strtoumax _strtoui64 + +// 7.8.2.4 The wcstoimax and wcstoumax functions +#define wcstoimax _wcstoi64 +#define wcstoumax _wcstoui64 + +#endif // _MSC_VER >= 1800 + +#endif // _MSC_INTTYPES_H_ ] diff --git a/include/rapidjson/msinttypes/stdint.h b/include/rapidjson/msinttypes/stdint.h new file mode 100644 index 0000000000..3d4477b9a0 --- /dev/null +++ b/include/rapidjson/msinttypes/stdint.h @@ -0,0 +1,300 @@ +// ISO C9x compliant stdint.h for Microsoft Visual Studio +// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 +// +// Copyright (c) 2006-2013 Alexander Chemeris +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the product nor the names of its contributors may +// be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO +// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +/////////////////////////////////////////////////////////////////////////////// + +// The above software in this distribution may have been modified by +// THL A29 Limited ("Tencent Modifications"). +// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited. + +#ifndef _MSC_VER // [ +#error "Use this header only with Microsoft Visual C++ compilers!" +#endif // _MSC_VER ] + +#ifndef _MSC_STDINT_H_ // [ +#define _MSC_STDINT_H_ + +#if _MSC_VER > 1000 +#pragma once +#endif + +// miloyip: Originally Visual Studio 2010 uses its own stdint.h. However it generates warning with INT64_C(), so change to use this file for vs2010. +#if _MSC_VER >= 1600 // [ +#include + +#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260 + +#undef INT8_C +#undef INT16_C +#undef INT32_C +#undef INT64_C +#undef UINT8_C +#undef UINT16_C +#undef UINT32_C +#undef UINT64_C + +// 7.18.4.1 Macros for minimum-width integer constants + +#define INT8_C(val) val##i8 +#define INT16_C(val) val##i16 +#define INT32_C(val) val##i32 +#define INT64_C(val) val##i64 + +#define UINT8_C(val) val##ui8 +#define UINT16_C(val) val##ui16 +#define UINT32_C(val) val##ui32 +#define UINT64_C(val) val##ui64 + +// 7.18.4.2 Macros for greatest-width integer constants +// These #ifndef's are needed to prevent collisions with . +// Check out Issue 9 for the details. +#ifndef INTMAX_C // [ +# define INTMAX_C INT64_C +#endif // INTMAX_C ] +#ifndef UINTMAX_C // [ +# define UINTMAX_C UINT64_C +#endif // UINTMAX_C ] + +#endif // __STDC_CONSTANT_MACROS ] + +#else // ] _MSC_VER >= 1700 [ + +#include + +// For Visual Studio 6 in C++ mode and for many Visual Studio versions when +// compiling for ARM we have to wrap include with 'extern "C++" {}' +// or compiler would give many errors like this: +// error C2733: second C linkage of overloaded function 'wmemchr' not allowed +#if defined(__cplusplus) && !defined(_M_ARM) +extern "C" { +#endif +# include +#if defined(__cplusplus) && !defined(_M_ARM) +} +#endif + +// Define _W64 macros to mark types changing their size, like intptr_t. +#ifndef _W64 +# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300 +# define _W64 __w64 +# else +# define _W64 +# endif +#endif + + +// 7.18.1 Integer types + +// 7.18.1.1 Exact-width integer types + +// Visual Studio 6 and Embedded Visual C++ 4 doesn't +// realize that, e.g. char has the same size as __int8 +// so we give up on __intX for them. +#if (_MSC_VER < 1300) + typedef signed char int8_t; + typedef signed short int16_t; + typedef signed int int32_t; + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; +#else + typedef signed __int8 int8_t; + typedef signed __int16 int16_t; + typedef signed __int32 int32_t; + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; +#endif +typedef signed __int64 int64_t; +typedef unsigned __int64 uint64_t; + + +// 7.18.1.2 Minimum-width integer types +typedef int8_t int_least8_t; +typedef int16_t int_least16_t; +typedef int32_t int_least32_t; +typedef int64_t int_least64_t; +typedef uint8_t uint_least8_t; +typedef uint16_t uint_least16_t; +typedef uint32_t uint_least32_t; +typedef uint64_t uint_least64_t; + +// 7.18.1.3 Fastest minimum-width integer types +typedef int8_t int_fast8_t; +typedef int16_t int_fast16_t; +typedef int32_t int_fast32_t; +typedef int64_t int_fast64_t; +typedef uint8_t uint_fast8_t; +typedef uint16_t uint_fast16_t; +typedef uint32_t uint_fast32_t; +typedef uint64_t uint_fast64_t; + +// 7.18.1.4 Integer types capable of holding object pointers +#ifdef _WIN64 // [ + typedef signed __int64 intptr_t; + typedef unsigned __int64 uintptr_t; +#else // _WIN64 ][ + typedef _W64 signed int intptr_t; + typedef _W64 unsigned int uintptr_t; +#endif // _WIN64 ] + +// 7.18.1.5 Greatest-width integer types +typedef int64_t intmax_t; +typedef uint64_t uintmax_t; + + +// 7.18.2 Limits of specified-width integer types + +#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259 + +// 7.18.2.1 Limits of exact-width integer types +#define INT8_MIN ((int8_t)_I8_MIN) +#define INT8_MAX _I8_MAX +#define INT16_MIN ((int16_t)_I16_MIN) +#define INT16_MAX _I16_MAX +#define INT32_MIN ((int32_t)_I32_MIN) +#define INT32_MAX _I32_MAX +#define INT64_MIN ((int64_t)_I64_MIN) +#define INT64_MAX _I64_MAX +#define UINT8_MAX _UI8_MAX +#define UINT16_MAX _UI16_MAX +#define UINT32_MAX _UI32_MAX +#define UINT64_MAX _UI64_MAX + +// 7.18.2.2 Limits of minimum-width integer types +#define INT_LEAST8_MIN INT8_MIN +#define INT_LEAST8_MAX INT8_MAX +#define INT_LEAST16_MIN INT16_MIN +#define INT_LEAST16_MAX INT16_MAX +#define INT_LEAST32_MIN INT32_MIN +#define INT_LEAST32_MAX INT32_MAX +#define INT_LEAST64_MIN INT64_MIN +#define INT_LEAST64_MAX INT64_MAX +#define UINT_LEAST8_MAX UINT8_MAX +#define UINT_LEAST16_MAX UINT16_MAX +#define UINT_LEAST32_MAX UINT32_MAX +#define UINT_LEAST64_MAX UINT64_MAX + +// 7.18.2.3 Limits of fastest minimum-width integer types +#define INT_FAST8_MIN INT8_MIN +#define INT_FAST8_MAX INT8_MAX +#define INT_FAST16_MIN INT16_MIN +#define INT_FAST16_MAX INT16_MAX +#define INT_FAST32_MIN INT32_MIN +#define INT_FAST32_MAX INT32_MAX +#define INT_FAST64_MIN INT64_MIN +#define INT_FAST64_MAX INT64_MAX +#define UINT_FAST8_MAX UINT8_MAX +#define UINT_FAST16_MAX UINT16_MAX +#define UINT_FAST32_MAX UINT32_MAX +#define UINT_FAST64_MAX UINT64_MAX + +// 7.18.2.4 Limits of integer types capable of holding object pointers +#ifdef _WIN64 // [ +# define INTPTR_MIN INT64_MIN +# define INTPTR_MAX INT64_MAX +# define UINTPTR_MAX UINT64_MAX +#else // _WIN64 ][ +# define INTPTR_MIN INT32_MIN +# define INTPTR_MAX INT32_MAX +# define UINTPTR_MAX UINT32_MAX +#endif // _WIN64 ] + +// 7.18.2.5 Limits of greatest-width integer types +#define INTMAX_MIN INT64_MIN +#define INTMAX_MAX INT64_MAX +#define UINTMAX_MAX UINT64_MAX + +// 7.18.3 Limits of other integer types + +#ifdef _WIN64 // [ +# define PTRDIFF_MIN _I64_MIN +# define PTRDIFF_MAX _I64_MAX +#else // _WIN64 ][ +# define PTRDIFF_MIN _I32_MIN +# define PTRDIFF_MAX _I32_MAX +#endif // _WIN64 ] + +#define SIG_ATOMIC_MIN INT_MIN +#define SIG_ATOMIC_MAX INT_MAX + +#ifndef SIZE_MAX // [ +# ifdef _WIN64 // [ +# define SIZE_MAX _UI64_MAX +# else // _WIN64 ][ +# define SIZE_MAX _UI32_MAX +# endif // _WIN64 ] +#endif // SIZE_MAX ] + +// WCHAR_MIN and WCHAR_MAX are also defined in +#ifndef WCHAR_MIN // [ +# define WCHAR_MIN 0 +#endif // WCHAR_MIN ] +#ifndef WCHAR_MAX // [ +# define WCHAR_MAX _UI16_MAX +#endif // WCHAR_MAX ] + +#define WINT_MIN 0 +#define WINT_MAX _UI16_MAX + +#endif // __STDC_LIMIT_MACROS ] + + +// 7.18.4 Limits of other integer types + +#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260 + +// 7.18.4.1 Macros for minimum-width integer constants + +#define INT8_C(val) val##i8 +#define INT16_C(val) val##i16 +#define INT32_C(val) val##i32 +#define INT64_C(val) val##i64 + +#define UINT8_C(val) val##ui8 +#define UINT16_C(val) val##ui16 +#define UINT32_C(val) val##ui32 +#define UINT64_C(val) val##ui64 + +// 7.18.4.2 Macros for greatest-width integer constants +// These #ifndef's are needed to prevent collisions with . +// Check out Issue 9 for the details. +#ifndef INTMAX_C // [ +# define INTMAX_C INT64_C +#endif // INTMAX_C ] +#ifndef UINTMAX_C // [ +# define UINTMAX_C UINT64_C +#endif // UINTMAX_C ] + +#endif // __STDC_CONSTANT_MACROS ] + +#endif // _MSC_VER >= 1600 ] + +#endif // _MSC_STDINT_H_ ] diff --git a/include/rapidjson/ostreamwrapper.h b/include/rapidjson/ostreamwrapper.h new file mode 100644 index 0000000000..11ed4d33f9 --- /dev/null +++ b/include/rapidjson/ostreamwrapper.h @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_OSTREAMWRAPPER_H_ +#define RAPIDJSON_OSTREAMWRAPPER_H_ + +#include "stream.h" +#include + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Wrapper of \c std::basic_ostream into RapidJSON's Stream concept. +/*! + The classes can be wrapped including but not limited to: + + - \c std::ostringstream + - \c std::stringstream + - \c std::wpstringstream + - \c std::wstringstream + - \c std::ifstream + - \c std::fstream + - \c std::wofstream + - \c std::wfstream + + \tparam StreamType Class derived from \c std::basic_ostream. +*/ + +template +class BasicOStreamWrapper { +public: + typedef typename StreamType::char_type Ch; + BasicOStreamWrapper(StreamType& stream) : stream_(stream) {} + + void Put(Ch c) { + stream_.put(c); + } + + void Flush() { + stream_.flush(); + } + + // Not implemented + char Peek() const { RAPIDJSON_ASSERT(false); return 0; } + char Take() { RAPIDJSON_ASSERT(false); return 0; } + size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; } + char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; } + +private: + BasicOStreamWrapper(const BasicOStreamWrapper&); + BasicOStreamWrapper& operator=(const BasicOStreamWrapper&); + + StreamType& stream_; +}; + +typedef BasicOStreamWrapper OStreamWrapper; +typedef BasicOStreamWrapper WOStreamWrapper; + +#ifdef __clang__ +RAPIDJSON_DIAG_POP +#endif + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_OSTREAMWRAPPER_H_ diff --git a/include/rapidjson/pointer.h b/include/rapidjson/pointer.h new file mode 100644 index 0000000000..355929ede0 --- /dev/null +++ b/include/rapidjson/pointer.h @@ -0,0 +1,1482 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_POINTER_H_ +#define RAPIDJSON_POINTER_H_ + +#include "document.h" +#include "uri.h" +#include "internal/itoa.h" +#include "error/error.h" // PointerParseErrorCode + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(switch-enum) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +#if defined(RAPIDJSON_CPLUSPLUS) && RAPIDJSON_CPLUSPLUS >= 201703L +#define RAPIDJSON_IF_CONSTEXPR if constexpr +#else +#define RAPIDJSON_IF_CONSTEXPR if +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +static const SizeType kPointerInvalidIndex = ~SizeType(0); //!< Represents an invalid index in GenericPointer::Token + +/////////////////////////////////////////////////////////////////////////////// +// GenericPointer + +//! Represents a JSON Pointer. Use Pointer for UTF8 encoding and default allocator. +/*! + This class implements RFC 6901 "JavaScript Object Notation (JSON) Pointer" + (https://tools.ietf.org/html/rfc6901). + + A JSON pointer is for identifying a specific value in a JSON document + (GenericDocument). It can simplify coding of DOM tree manipulation, because it + can access multiple-level depth of DOM tree with single API call. + + After it parses a string representation (e.g. "/foo/0" or URI fragment + representation (e.g. "#/foo/0") into its internal representation (tokens), + it can be used to resolve a specific value in multiple documents, or sub-tree + of documents. + + Contrary to GenericValue, Pointer can be copy constructed and copy assigned. + Apart from assignment, a Pointer cannot be modified after construction. + + Although Pointer is very convenient, please aware that constructing Pointer + involves parsing and dynamic memory allocation. A special constructor with user- + supplied tokens eliminates these. + + GenericPointer depends on GenericDocument and GenericValue. + + \tparam ValueType The value type of the DOM tree. E.g. GenericValue > + \tparam Allocator The allocator type for allocating memory for internal representation. + + \note GenericPointer uses same encoding of ValueType. + However, Allocator of GenericPointer is independent of Allocator of Value. +*/ +template +class GenericPointer { +public: + typedef typename ValueType::EncodingType EncodingType; //!< Encoding type from Value + typedef typename ValueType::Ch Ch; //!< Character type from Value + typedef GenericUri UriType; + + + //! A token is the basic units of internal representation. + /*! + A JSON pointer string representation "/foo/123" is parsed to two tokens: + "foo" and 123. 123 will be represented in both numeric form and string form. + They are resolved according to the actual value type (object or array). + + For token that are not numbers, or the numeric value is out of bound + (greater than limits of SizeType), they are only treated as string form + (i.e. the token's index will be equal to kPointerInvalidIndex). + + This struct is public so that user can create a Pointer without parsing and + allocation, using a special constructor. + */ + struct Token { + const Ch* name; //!< Name of the token. It has null character at the end but it can contain null character. + SizeType length; //!< Length of the name. + SizeType index; //!< A valid array index, if it is not equal to kPointerInvalidIndex. + }; + + //!@name Constructors and destructor. + //@{ + + //! Default constructor. + GenericPointer(Allocator* allocator = 0) : allocator_(allocator), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) {} + + //! Constructor that parses a string or URI fragment representation. + /*! + \param source A null-terminated, string or URI fragment representation of JSON pointer. + \param allocator User supplied allocator for this pointer. If no allocator is provided, it creates a self-owned one. + */ + explicit GenericPointer(const Ch* source, Allocator* allocator = 0) : allocator_(allocator), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) { + Parse(source, internal::StrLen(source)); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Constructor that parses a string or URI fragment representation. + /*! + \param source A string or URI fragment representation of JSON pointer. + \param allocator User supplied allocator for this pointer. If no allocator is provided, it creates a self-owned one. + \note Requires the definition of the preprocessor symbol \ref RAPIDJSON_HAS_STDSTRING. + */ + explicit GenericPointer(const std::basic_string& source, Allocator* allocator = 0) : allocator_(allocator), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) { + Parse(source.c_str(), source.size()); + } +#endif + + //! Constructor that parses a string or URI fragment representation, with length of the source string. + /*! + \param source A string or URI fragment representation of JSON pointer. + \param length Length of source. + \param allocator User supplied allocator for this pointer. If no allocator is provided, it creates a self-owned one. + \note Slightly faster than the overload without length. + */ + GenericPointer(const Ch* source, size_t length, Allocator* allocator = 0) : allocator_(allocator), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) { + Parse(source, length); + } + + //! Constructor with user-supplied tokens. + /*! + This constructor let user supplies const array of tokens. + This prevents the parsing process and eliminates allocation. + This is preferred for memory constrained environments. + + \param tokens An constant array of tokens representing the JSON pointer. + \param tokenCount Number of tokens. + + \b Example + \code + #define NAME(s) { s, sizeof(s) / sizeof(s[0]) - 1, kPointerInvalidIndex } + #define INDEX(i) { #i, sizeof(#i) - 1, i } + + static const Pointer::Token kTokens[] = { NAME("foo"), INDEX(123) }; + static const Pointer p(kTokens, sizeof(kTokens) / sizeof(kTokens[0])); + // Equivalent to static const Pointer p("/foo/123"); + + #undef NAME + #undef INDEX + \endcode + */ + GenericPointer(const Token* tokens, size_t tokenCount) : allocator_(), ownAllocator_(), nameBuffer_(), tokens_(const_cast(tokens)), tokenCount_(tokenCount), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) {} + + //! Copy constructor. + GenericPointer(const GenericPointer& rhs) : allocator_(), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) { + *this = rhs; + } + + //! Copy constructor. + GenericPointer(const GenericPointer& rhs, Allocator* allocator) : allocator_(allocator), ownAllocator_(), nameBuffer_(), tokens_(), tokenCount_(), parseErrorOffset_(), parseErrorCode_(kPointerParseErrorNone) { + *this = rhs; + } + + //! Destructor. + ~GenericPointer() { + if (nameBuffer_) // If user-supplied tokens constructor is used, nameBuffer_ is nullptr and tokens_ are not deallocated. + Allocator::Free(tokens_); + RAPIDJSON_DELETE(ownAllocator_); + } + + //! Assignment operator. + GenericPointer& operator=(const GenericPointer& rhs) { + if (this != &rhs) { + // Do not delete ownAllcator + if (nameBuffer_) + Allocator::Free(tokens_); + + tokenCount_ = rhs.tokenCount_; + parseErrorOffset_ = rhs.parseErrorOffset_; + parseErrorCode_ = rhs.parseErrorCode_; + + if (rhs.nameBuffer_) + CopyFromRaw(rhs); // Normally parsed tokens. + else { + tokens_ = rhs.tokens_; // User supplied const tokens. + nameBuffer_ = 0; + } + } + return *this; + } + + //! Swap the content of this pointer with an other. + /*! + \param other The pointer to swap with. + \note Constant complexity. + */ + GenericPointer& Swap(GenericPointer& other) RAPIDJSON_NOEXCEPT { + internal::Swap(allocator_, other.allocator_); + internal::Swap(ownAllocator_, other.ownAllocator_); + internal::Swap(nameBuffer_, other.nameBuffer_); + internal::Swap(tokens_, other.tokens_); + internal::Swap(tokenCount_, other.tokenCount_); + internal::Swap(parseErrorOffset_, other.parseErrorOffset_); + internal::Swap(parseErrorCode_, other.parseErrorCode_); + return *this; + } + + //! free-standing swap function helper + /*! + Helper function to enable support for common swap implementation pattern based on \c std::swap: + \code + void swap(MyClass& a, MyClass& b) { + using std::swap; + swap(a.pointer, b.pointer); + // ... + } + \endcode + \see Swap() + */ + friend inline void swap(GenericPointer& a, GenericPointer& b) RAPIDJSON_NOEXCEPT { a.Swap(b); } + + //@} + + //!@name Append token + //@{ + + //! Append a token and return a new Pointer + /*! + \param token Token to be appended. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + GenericPointer Append(const Token& token, Allocator* allocator = 0) const { + GenericPointer r; + r.allocator_ = allocator; + Ch *p = r.CopyFromRaw(*this, 1, token.length + 1); + std::memcpy(p, token.name, (token.length + 1) * sizeof(Ch)); + r.tokens_[tokenCount_].name = p; + r.tokens_[tokenCount_].length = token.length; + r.tokens_[tokenCount_].index = token.index; + return r; + } + + //! Append a name token with length, and return a new Pointer + /*! + \param name Name to be appended. + \param length Length of name. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + GenericPointer Append(const Ch* name, SizeType length, Allocator* allocator = 0) const { + Token token = { name, length, kPointerInvalidIndex }; + return Append(token, allocator); + } + + //! Append a name token without length, and return a new Pointer + /*! + \param name Name (const Ch*) to be appended. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::NotExpr::Type, Ch> >), (GenericPointer)) + Append(T* name, Allocator* allocator = 0) const { + return Append(name, internal::StrLen(name), allocator); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Append a name token, and return a new Pointer + /*! + \param name Name to be appended. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + GenericPointer Append(const std::basic_string& name, Allocator* allocator = 0) const { + return Append(name.c_str(), static_cast(name.size()), allocator); + } +#endif + + //! Append a index token, and return a new Pointer + /*! + \param index Index to be appended. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + GenericPointer Append(SizeType index, Allocator* allocator = 0) const { + char buffer[21]; + char* end = sizeof(SizeType) == 4 ? internal::u32toa(index, buffer) : internal::u64toa(index, buffer); + SizeType length = static_cast(end - buffer); + buffer[length] = '\0'; + + RAPIDJSON_IF_CONSTEXPR (sizeof(Ch) == 1) { + Token token = { reinterpret_cast(buffer), length, index }; + return Append(token, allocator); + } + else { + Ch name[21]; + for (size_t i = 0; i <= length; i++) + name[i] = static_cast(buffer[i]); + Token token = { name, length, index }; + return Append(token, allocator); + } + } + + //! Append a token by value, and return a new Pointer + /*! + \param token token to be appended. + \param allocator Allocator for the newly return Pointer. + \return A new Pointer with appended token. + */ + GenericPointer Append(const ValueType& token, Allocator* allocator = 0) const { + if (token.IsString()) + return Append(token.GetString(), token.GetStringLength(), allocator); + else { + RAPIDJSON_ASSERT(token.IsUint64()); + RAPIDJSON_ASSERT(token.GetUint64() <= SizeType(~0)); + return Append(static_cast(token.GetUint64()), allocator); + } + } + + //!@name Handling Parse Error + //@{ + + //! Check whether this is a valid pointer. + bool IsValid() const { return parseErrorCode_ == kPointerParseErrorNone; } + + //! Get the parsing error offset in code unit. + size_t GetParseErrorOffset() const { return parseErrorOffset_; } + + //! Get the parsing error code. + PointerParseErrorCode GetParseErrorCode() const { return parseErrorCode_; } + + //@} + + //! Get the allocator of this pointer. + Allocator& GetAllocator() { return *allocator_; } + + //!@name Tokens + //@{ + + //! Get the token array (const version only). + const Token* GetTokens() const { return tokens_; } + + //! Get the number of tokens. + size_t GetTokenCount() const { return tokenCount_; } + + //@} + + //!@name Equality/inequality operators + //@{ + + //! Equality operator. + /*! + \note When any pointers are invalid, always returns false. + */ + bool operator==(const GenericPointer& rhs) const { + if (!IsValid() || !rhs.IsValid() || tokenCount_ != rhs.tokenCount_) + return false; + + for (size_t i = 0; i < tokenCount_; i++) { + if (tokens_[i].index != rhs.tokens_[i].index || + tokens_[i].length != rhs.tokens_[i].length || + (tokens_[i].length != 0 && std::memcmp(tokens_[i].name, rhs.tokens_[i].name, sizeof(Ch)* tokens_[i].length) != 0)) + { + return false; + } + } + + return true; + } + + //! Inequality operator. + /*! + \note When any pointers are invalid, always returns true. + */ + bool operator!=(const GenericPointer& rhs) const { return !(*this == rhs); } + + //! Less than operator. + /*! + \note Invalid pointers are always greater than valid ones. + */ + bool operator<(const GenericPointer& rhs) const { + if (!IsValid()) + return false; + if (!rhs.IsValid()) + return true; + + if (tokenCount_ != rhs.tokenCount_) + return tokenCount_ < rhs.tokenCount_; + + for (size_t i = 0; i < tokenCount_; i++) { + if (tokens_[i].index != rhs.tokens_[i].index) + return tokens_[i].index < rhs.tokens_[i].index; + + if (tokens_[i].length != rhs.tokens_[i].length) + return tokens_[i].length < rhs.tokens_[i].length; + + if (int cmp = std::memcmp(tokens_[i].name, rhs.tokens_[i].name, sizeof(Ch) * tokens_[i].length)) + return cmp < 0; + } + + return false; + } + + //@} + + //!@name Stringify + //@{ + + //! Stringify the pointer into string representation. + /*! + \tparam OutputStream Type of output stream. + \param os The output stream. + */ + template + bool Stringify(OutputStream& os) const { + return Stringify(os); + } + + //! Stringify the pointer into URI fragment representation. + /*! + \tparam OutputStream Type of output stream. + \param os The output stream. + */ + template + bool StringifyUriFragment(OutputStream& os) const { + return Stringify(os); + } + + //@} + + //!@name Create value + //@{ + + //! Create a value in a subtree. + /*! + If the value is not exist, it creates all parent values and a JSON Null value. + So it always succeed and return the newly created or existing value. + + Remind that it may change types of parents according to tokens, so it + potentially removes previously stored values. For example, if a document + was an array, and "/foo" is used to create a value, then the document + will be changed to an object, and all existing array elements are lost. + + \param root Root value of a DOM subtree to be resolved. It can be any value other than document root. + \param allocator Allocator for creating the values if the specified value or its parents are not exist. + \param alreadyExist If non-null, it stores whether the resolved value is already exist. + \return The resolved newly created (a JSON Null value), or already exists value. + */ + ValueType& Create(ValueType& root, typename ValueType::AllocatorType& allocator, bool* alreadyExist = 0) const { + RAPIDJSON_ASSERT(IsValid()); + ValueType* v = &root; + bool exist = true; + for (const Token *t = tokens_; t != tokens_ + tokenCount_; ++t) { + if (v->IsArray() && t->name[0] == '-' && t->length == 1) { + v->PushBack(ValueType().Move(), allocator); + v = &((*v)[v->Size() - 1]); + exist = false; + } + else { + if (t->index == kPointerInvalidIndex) { // must be object name + if (!v->IsObject()) + v->SetObject(); // Change to Object + } + else { // object name or array index + if (!v->IsArray() && !v->IsObject()) + v->SetArray(); // Change to Array + } + + if (v->IsArray()) { + if (t->index >= v->Size()) { + v->Reserve(t->index + 1, allocator); + while (t->index >= v->Size()) + v->PushBack(ValueType().Move(), allocator); + exist = false; + } + v = &((*v)[t->index]); + } + else { + typename ValueType::MemberIterator m = v->FindMember(GenericValue(GenericStringRef(t->name, t->length))); + if (m == v->MemberEnd()) { + v->AddMember(ValueType(t->name, t->length, allocator).Move(), ValueType().Move(), allocator); + m = v->MemberEnd(); + v = &(--m)->value; // Assumes AddMember() appends at the end + exist = false; + } + else + v = &m->value; + } + } + } + + if (alreadyExist) + *alreadyExist = exist; + + return *v; + } + + //! Creates a value in a document. + /*! + \param document A document to be resolved. + \param alreadyExist If non-null, it stores whether the resolved value is already exist. + \return The resolved newly created, or already exists value. + */ + template + ValueType& Create(GenericDocument& document, bool* alreadyExist = 0) const { + return Create(document, document.GetAllocator(), alreadyExist); + } + + //@} + + //!@name Compute URI + //@{ + + //! Compute the in-scope URI for a subtree. + // For use with JSON pointers into JSON schema documents. + /*! + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \param rootUri Root URI + \param unresolvedTokenIndex If the pointer cannot resolve a token in the pointer, this parameter can obtain the index of unresolved token. + \param allocator Allocator for Uris + \return Uri if it can be resolved. Otherwise null. + + \note + There are only 3 situations when a URI cannot be resolved: + 1. A value in the path is not an array nor object. + 2. An object value does not contain the token. + 3. A token is out of range of an array value. + + Use unresolvedTokenIndex to retrieve the token index. + */ + UriType GetUri(ValueType& root, const UriType& rootUri, size_t* unresolvedTokenIndex = 0, Allocator* allocator = 0) const { + static const Ch kIdString[] = { 'i', 'd', '\0' }; + static const ValueType kIdValue(kIdString, 2); + UriType base = UriType(rootUri, allocator); + RAPIDJSON_ASSERT(IsValid()); + ValueType* v = &root; + for (const Token *t = tokens_; t != tokens_ + tokenCount_; ++t) { + switch (v->GetType()) { + case kObjectType: + { + // See if we have an id, and if so resolve with the current base + typename ValueType::MemberIterator m = v->FindMember(kIdValue); + if (m != v->MemberEnd() && (m->value).IsString()) { + UriType here = UriType(m->value, allocator).Resolve(base, allocator); + base = here; + } + m = v->FindMember(GenericValue(GenericStringRef(t->name, t->length))); + if (m == v->MemberEnd()) + break; + v = &m->value; + } + continue; + case kArrayType: + if (t->index == kPointerInvalidIndex || t->index >= v->Size()) + break; + v = &((*v)[t->index]); + continue; + default: + break; + } + + // Error: unresolved token + if (unresolvedTokenIndex) + *unresolvedTokenIndex = static_cast(t - tokens_); + return UriType(allocator); + } + return base; + } + + UriType GetUri(const ValueType& root, const UriType& rootUri, size_t* unresolvedTokenIndex = 0, Allocator* allocator = 0) const { + return GetUri(const_cast(root), rootUri, unresolvedTokenIndex, allocator); + } + + + //!@name Query value + //@{ + + //! Query a value in a subtree. + /*! + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \param unresolvedTokenIndex If the pointer cannot resolve a token in the pointer, this parameter can obtain the index of unresolved token. + \return Pointer to the value if it can be resolved. Otherwise null. + + \note + There are only 3 situations when a value cannot be resolved: + 1. A value in the path is not an array nor object. + 2. An object value does not contain the token. + 3. A token is out of range of an array value. + + Use unresolvedTokenIndex to retrieve the token index. + */ + ValueType* Get(ValueType& root, size_t* unresolvedTokenIndex = 0) const { + RAPIDJSON_ASSERT(IsValid()); + ValueType* v = &root; + for (const Token *t = tokens_; t != tokens_ + tokenCount_; ++t) { + switch (v->GetType()) { + case kObjectType: + { + typename ValueType::MemberIterator m = v->FindMember(GenericValue(GenericStringRef(t->name, t->length))); + if (m == v->MemberEnd()) + break; + v = &m->value; + } + continue; + case kArrayType: + if (t->index == kPointerInvalidIndex || t->index >= v->Size()) + break; + v = &((*v)[t->index]); + continue; + default: + break; + } + + // Error: unresolved token + if (unresolvedTokenIndex) + *unresolvedTokenIndex = static_cast(t - tokens_); + return 0; + } + return v; + } + + //! Query a const value in a const subtree. + /*! + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \return Pointer to the value if it can be resolved. Otherwise null. + */ + const ValueType* Get(const ValueType& root, size_t* unresolvedTokenIndex = 0) const { + return Get(const_cast(root), unresolvedTokenIndex); + } + + //@} + + //!@name Query a value with default + //@{ + + //! Query a value in a subtree with default value. + /*! + Similar to Get(), but if the specified value do not exists, it creates all parents and clone the default value. + So that this function always succeed. + + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \param defaultValue Default value to be cloned if the value was not exists. + \param allocator Allocator for creating the values if the specified value or its parents are not exist. + \see Create() + */ + ValueType& GetWithDefault(ValueType& root, const ValueType& defaultValue, typename ValueType::AllocatorType& allocator) const { + bool alreadyExist; + ValueType& v = Create(root, allocator, &alreadyExist); + return alreadyExist ? v : v.CopyFrom(defaultValue, allocator); + } + + //! Query a value in a subtree with default null-terminated string. + ValueType& GetWithDefault(ValueType& root, const Ch* defaultValue, typename ValueType::AllocatorType& allocator) const { + bool alreadyExist; + ValueType& v = Create(root, allocator, &alreadyExist); + return alreadyExist ? v : v.SetString(defaultValue, allocator); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Query a value in a subtree with default std::basic_string. + ValueType& GetWithDefault(ValueType& root, const std::basic_string& defaultValue, typename ValueType::AllocatorType& allocator) const { + bool alreadyExist; + ValueType& v = Create(root, allocator, &alreadyExist); + return alreadyExist ? v : v.SetString(defaultValue, allocator); + } +#endif + + //! Query a value in a subtree with default primitive value. + /*! + \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c bool + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (ValueType&)) + GetWithDefault(ValueType& root, T defaultValue, typename ValueType::AllocatorType& allocator) const { + return GetWithDefault(root, ValueType(defaultValue).Move(), allocator); + } + + //! Query a value in a document with default value. + template + ValueType& GetWithDefault(GenericDocument& document, const ValueType& defaultValue) const { + return GetWithDefault(document, defaultValue, document.GetAllocator()); + } + + //! Query a value in a document with default null-terminated string. + template + ValueType& GetWithDefault(GenericDocument& document, const Ch* defaultValue) const { + return GetWithDefault(document, defaultValue, document.GetAllocator()); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Query a value in a document with default std::basic_string. + template + ValueType& GetWithDefault(GenericDocument& document, const std::basic_string& defaultValue) const { + return GetWithDefault(document, defaultValue, document.GetAllocator()); + } +#endif + + //! Query a value in a document with default primitive value. + /*! + \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c bool + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (ValueType&)) + GetWithDefault(GenericDocument& document, T defaultValue) const { + return GetWithDefault(document, defaultValue, document.GetAllocator()); + } + + //@} + + //!@name Set a value + //@{ + + //! Set a value in a subtree, with move semantics. + /*! + It creates all parents if they are not exist or types are different to the tokens. + So this function always succeeds but potentially remove existing values. + + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \param value Value to be set. + \param allocator Allocator for creating the values if the specified value or its parents are not exist. + \see Create() + */ + ValueType& Set(ValueType& root, ValueType& value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator) = value; + } + + //! Set a value in a subtree, with copy semantics. + ValueType& Set(ValueType& root, const ValueType& value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator).CopyFrom(value, allocator); + } + + //! Set a null-terminated string in a subtree. + ValueType& Set(ValueType& root, const Ch* value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator) = ValueType(value, allocator).Move(); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Set a std::basic_string in a subtree. + ValueType& Set(ValueType& root, const std::basic_string& value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator) = ValueType(value, allocator).Move(); + } +#endif + + //! Set a primitive value in a subtree. + /*! + \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c bool + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (ValueType&)) + Set(ValueType& root, T value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator) = ValueType(value).Move(); + } + + //! Set a value in a document, with move semantics. + template + ValueType& Set(GenericDocument& document, ValueType& value) const { + return Create(document) = value; + } + + //! Set a value in a document, with copy semantics. + template + ValueType& Set(GenericDocument& document, const ValueType& value) const { + return Create(document).CopyFrom(value, document.GetAllocator()); + } + + //! Set a null-terminated string in a document. + template + ValueType& Set(GenericDocument& document, const Ch* value) const { + return Create(document) = ValueType(value, document.GetAllocator()).Move(); + } + +#if RAPIDJSON_HAS_STDSTRING + //! Sets a std::basic_string in a document. + template + ValueType& Set(GenericDocument& document, const std::basic_string& value) const { + return Create(document) = ValueType(value, document.GetAllocator()).Move(); + } +#endif + + //! Set a primitive value in a document. + /*! + \tparam T Either \ref Type, \c int, \c unsigned, \c int64_t, \c uint64_t, \c bool + */ + template + RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (ValueType&)) + Set(GenericDocument& document, T value) const { + return Create(document) = value; + } + + //@} + + //!@name Swap a value + //@{ + + //! Swap a value with a value in a subtree. + /*! + It creates all parents if they are not exist or types are different to the tokens. + So this function always succeeds but potentially remove existing values. + + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \param value Value to be swapped. + \param allocator Allocator for creating the values if the specified value or its parents are not exist. + \see Create() + */ + ValueType& Swap(ValueType& root, ValueType& value, typename ValueType::AllocatorType& allocator) const { + return Create(root, allocator).Swap(value); + } + + //! Swap a value with a value in a document. + template + ValueType& Swap(GenericDocument& document, ValueType& value) const { + return Create(document).Swap(value); + } + + //@} + + //! Erase a value in a subtree. + /*! + \param root Root value of a DOM sub-tree to be resolved. It can be any value other than document root. + \return Whether the resolved value is found and erased. + + \note Erasing with an empty pointer \c Pointer(""), i.e. the root, always fail and return false. + */ + bool Erase(ValueType& root) const { + RAPIDJSON_ASSERT(IsValid()); + if (tokenCount_ == 0) // Cannot erase the root + return false; + + ValueType* v = &root; + const Token* last = tokens_ + (tokenCount_ - 1); + for (const Token *t = tokens_; t != last; ++t) { + switch (v->GetType()) { + case kObjectType: + { + typename ValueType::MemberIterator m = v->FindMember(GenericValue(GenericStringRef(t->name, t->length))); + if (m == v->MemberEnd()) + return false; + v = &m->value; + } + break; + case kArrayType: + if (t->index == kPointerInvalidIndex || t->index >= v->Size()) + return false; + v = &((*v)[t->index]); + break; + default: + return false; + } + } + + switch (v->GetType()) { + case kObjectType: + return v->EraseMember(GenericStringRef(last->name, last->length)); + case kArrayType: + if (last->index == kPointerInvalidIndex || last->index >= v->Size()) + return false; + v->Erase(v->Begin() + last->index); + return true; + default: + return false; + } + } + +private: + //! Clone the content from rhs to this. + /*! + \param rhs Source pointer. + \param extraToken Extra tokens to be allocated. + \param extraNameBufferSize Extra name buffer size (in number of Ch) to be allocated. + \return Start of non-occupied name buffer, for storing extra names. + */ + Ch* CopyFromRaw(const GenericPointer& rhs, size_t extraToken = 0, size_t extraNameBufferSize = 0) { + if (!allocator_) // allocator is independently owned. + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + + size_t nameBufferSize = rhs.tokenCount_; // null terminators for tokens + for (Token *t = rhs.tokens_; t != rhs.tokens_ + rhs.tokenCount_; ++t) + nameBufferSize += t->length; + + tokenCount_ = rhs.tokenCount_ + extraToken; + tokens_ = static_cast(allocator_->Malloc(tokenCount_ * sizeof(Token) + (nameBufferSize + extraNameBufferSize) * sizeof(Ch))); + nameBuffer_ = reinterpret_cast(tokens_ + tokenCount_); + if (rhs.tokenCount_ > 0) { + std::memcpy(tokens_, rhs.tokens_, rhs.tokenCount_ * sizeof(Token)); + } + if (nameBufferSize > 0) { + std::memcpy(nameBuffer_, rhs.nameBuffer_, nameBufferSize * sizeof(Ch)); + } + + // The names of each token point to a string in the nameBuffer_. The + // previous memcpy copied over string pointers into the rhs.nameBuffer_, + // but they should point to the strings in the new nameBuffer_. + for (size_t i = 0; i < rhs.tokenCount_; ++i) { + // The offset between the string address and the name buffer should + // still be constant, so we can just get this offset and set each new + // token name according the new buffer start + the known offset. + std::ptrdiff_t name_offset = rhs.tokens_[i].name - rhs.nameBuffer_; + tokens_[i].name = nameBuffer_ + name_offset; + } + + return nameBuffer_ + nameBufferSize; + } + + //! Check whether a character should be percent-encoded. + /*! + According to RFC 3986 2.3 Unreserved Characters. + \param c The character (code unit) to be tested. + */ + bool NeedPercentEncode(Ch c) const { + return !((c >= '0' && c <= '9') || (c >= 'A' && c <='Z') || (c >= 'a' && c <= 'z') || c == '-' || c == '.' || c == '_' || c =='~'); + } + + //! Parse a JSON String or its URI fragment representation into tokens. +#ifndef __clang__ // -Wdocumentation + /*! + \param source Either a JSON Pointer string, or its URI fragment representation. Not need to be null terminated. + \param length Length of the source string. + \note Source cannot be JSON String Representation of JSON Pointer, e.g. In "/\u0000", \u0000 will not be unescaped. + */ +#endif + void Parse(const Ch* source, size_t length) { + RAPIDJSON_ASSERT(source != NULL); + RAPIDJSON_ASSERT(nameBuffer_ == 0); + RAPIDJSON_ASSERT(tokens_ == 0); + + // Create own allocator if user did not supply. + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + + // Count number of '/' as tokenCount + tokenCount_ = 0; + for (const Ch* s = source; s != source + length; s++) + if (*s == '/') + tokenCount_++; + + Token* token = tokens_ = static_cast(allocator_->Malloc(tokenCount_ * sizeof(Token) + length * sizeof(Ch))); + Ch* name = nameBuffer_ = reinterpret_cast(tokens_ + tokenCount_); + size_t i = 0; + + // Detect if it is a URI fragment + bool uriFragment = false; + if (source[i] == '#') { + uriFragment = true; + i++; + } + + if (i != length && source[i] != '/') { + parseErrorCode_ = kPointerParseErrorTokenMustBeginWithSolidus; + goto error; + } + + while (i < length) { + RAPIDJSON_ASSERT(source[i] == '/'); + i++; // consumes '/' + + token->name = name; + bool isNumber = true; + + while (i < length && source[i] != '/') { + Ch c = source[i]; + if (uriFragment) { + // Decoding percent-encoding for URI fragment + if (c == '%') { + PercentDecodeStream is(&source[i], source + length); + GenericInsituStringStream os(name); + Ch* begin = os.PutBegin(); + if (!Transcoder, EncodingType>().Validate(is, os) || !is.IsValid()) { + parseErrorCode_ = kPointerParseErrorInvalidPercentEncoding; + goto error; + } + size_t len = os.PutEnd(begin); + i += is.Tell() - 1; + if (len == 1) + c = *name; + else { + name += len; + isNumber = false; + i++; + continue; + } + } + else if (NeedPercentEncode(c)) { + parseErrorCode_ = kPointerParseErrorCharacterMustPercentEncode; + goto error; + } + } + + i++; + + // Escaping "~0" -> '~', "~1" -> '/' + if (c == '~') { + if (i < length) { + c = source[i]; + if (c == '0') c = '~'; + else if (c == '1') c = '/'; + else { + parseErrorCode_ = kPointerParseErrorInvalidEscape; + goto error; + } + i++; + } + else { + parseErrorCode_ = kPointerParseErrorInvalidEscape; + goto error; + } + } + + // First check for index: all of characters are digit + if (c < '0' || c > '9') + isNumber = false; + + *name++ = c; + } + token->length = static_cast(name - token->name); + if (token->length == 0) + isNumber = false; + *name++ = '\0'; // Null terminator + + // Second check for index: more than one digit cannot have leading zero + if (isNumber && token->length > 1 && token->name[0] == '0') + isNumber = false; + + // String to SizeType conversion + SizeType n = 0; + if (isNumber) { + for (size_t j = 0; j < token->length; j++) { + SizeType m = n * 10 + static_cast(token->name[j] - '0'); + if (m < n) { // overflow detection + isNumber = false; + break; + } + n = m; + } + } + + token->index = isNumber ? n : kPointerInvalidIndex; + token++; + } + + RAPIDJSON_ASSERT(name <= nameBuffer_ + length); // Should not overflow buffer + parseErrorCode_ = kPointerParseErrorNone; + return; + + error: + Allocator::Free(tokens_); + nameBuffer_ = 0; + tokens_ = 0; + tokenCount_ = 0; + parseErrorOffset_ = i; + return; + } + + //! Stringify to string or URI fragment representation. + /*! + \tparam uriFragment True for stringifying to URI fragment representation. False for string representation. + \tparam OutputStream type of output stream. + \param os The output stream. + */ + template + bool Stringify(OutputStream& os) const { + RAPIDJSON_ASSERT(IsValid()); + + if (uriFragment) + os.Put('#'); + + for (Token *t = tokens_; t != tokens_ + tokenCount_; ++t) { + os.Put('/'); + for (size_t j = 0; j < t->length; j++) { + Ch c = t->name[j]; + if (c == '~') { + os.Put('~'); + os.Put('0'); + } + else if (c == '/') { + os.Put('~'); + os.Put('1'); + } + else if (uriFragment && NeedPercentEncode(c)) { + // Transcode to UTF8 sequence + GenericStringStream source(&t->name[j]); + PercentEncodeStream target(os); + if (!Transcoder >().Validate(source, target)) + return false; + j += source.Tell() - 1; + } + else + os.Put(c); + } + } + return true; + } + + //! A helper stream for decoding a percent-encoded sequence into code unit. + /*! + This stream decodes %XY triplet into code unit (0-255). + If it encounters invalid characters, it sets output code unit as 0 and + mark invalid, and to be checked by IsValid(). + */ + class PercentDecodeStream { + public: + typedef typename ValueType::Ch Ch; + + //! Constructor + /*! + \param source Start of the stream + \param end Past-the-end of the stream. + */ + PercentDecodeStream(const Ch* source, const Ch* end) : src_(source), head_(source), end_(end), valid_(true) {} + + Ch Take() { + if (*src_ != '%' || src_ + 3 > end_) { // %XY triplet + valid_ = false; + return 0; + } + src_++; + Ch c = 0; + for (int j = 0; j < 2; j++) { + c = static_cast(c << 4); + Ch h = *src_; + if (h >= '0' && h <= '9') c = static_cast(c + h - '0'); + else if (h >= 'A' && h <= 'F') c = static_cast(c + h - 'A' + 10); + else if (h >= 'a' && h <= 'f') c = static_cast(c + h - 'a' + 10); + else { + valid_ = false; + return 0; + } + src_++; + } + return c; + } + + size_t Tell() const { return static_cast(src_ - head_); } + bool IsValid() const { return valid_; } + + private: + const Ch* src_; //!< Current read position. + const Ch* head_; //!< Original head of the string. + const Ch* end_; //!< Past-the-end position. + bool valid_; //!< Whether the parsing is valid. + }; + + //! A helper stream to encode character (UTF-8 code unit) into percent-encoded sequence. + template + class PercentEncodeStream { + public: + PercentEncodeStream(OutputStream& os) : os_(os) {} + void Put(char c) { // UTF-8 must be byte + unsigned char u = static_cast(c); + static const char hexDigits[16] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' }; + os_.Put('%'); + os_.Put(static_cast(hexDigits[u >> 4])); + os_.Put(static_cast(hexDigits[u & 15])); + } + private: + OutputStream& os_; + }; + + Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to ownAllocator_. + Allocator* ownAllocator_; //!< Allocator owned by this Pointer. + Ch* nameBuffer_; //!< A buffer containing all names in tokens. + Token* tokens_; //!< A list of tokens. + size_t tokenCount_; //!< Number of tokens in tokens_. + size_t parseErrorOffset_; //!< Offset in code unit when parsing fail. + PointerParseErrorCode parseErrorCode_; //!< Parsing error code. +}; + +//! GenericPointer for Value (UTF-8, default allocator). +typedef GenericPointer Pointer; + +//!@name Helper functions for GenericPointer +//@{ + +////////////////////////////////////////////////////////////////////////////// + +template +typename T::ValueType& CreateValueByPointer(T& root, const GenericPointer& pointer, typename T::AllocatorType& a) { + return pointer.Create(root, a); +} + +template +typename T::ValueType& CreateValueByPointer(T& root, const CharType(&source)[N], typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Create(root, a); +} + +// No allocator parameter + +template +typename DocumentType::ValueType& CreateValueByPointer(DocumentType& document, const GenericPointer& pointer) { + return pointer.Create(document); +} + +template +typename DocumentType::ValueType& CreateValueByPointer(DocumentType& document, const CharType(&source)[N]) { + return GenericPointer(source, N - 1).Create(document); +} + +////////////////////////////////////////////////////////////////////////////// + +template +typename T::ValueType* GetValueByPointer(T& root, const GenericPointer& pointer, size_t* unresolvedTokenIndex = 0) { + return pointer.Get(root, unresolvedTokenIndex); +} + +template +const typename T::ValueType* GetValueByPointer(const T& root, const GenericPointer& pointer, size_t* unresolvedTokenIndex = 0) { + return pointer.Get(root, unresolvedTokenIndex); +} + +template +typename T::ValueType* GetValueByPointer(T& root, const CharType (&source)[N], size_t* unresolvedTokenIndex = 0) { + return GenericPointer(source, N - 1).Get(root, unresolvedTokenIndex); +} + +template +const typename T::ValueType* GetValueByPointer(const T& root, const CharType(&source)[N], size_t* unresolvedTokenIndex = 0) { + return GenericPointer(source, N - 1).Get(root, unresolvedTokenIndex); +} + +////////////////////////////////////////////////////////////////////////////// + +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const GenericPointer& pointer, const typename T::ValueType& defaultValue, typename T::AllocatorType& a) { + return pointer.GetWithDefault(root, defaultValue, a); +} + +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const GenericPointer& pointer, const typename T::Ch* defaultValue, typename T::AllocatorType& a) { + return pointer.GetWithDefault(root, defaultValue, a); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const GenericPointer& pointer, const std::basic_string& defaultValue, typename T::AllocatorType& a) { + return pointer.GetWithDefault(root, defaultValue, a); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename T::ValueType&)) +GetValueByPointerWithDefault(T& root, const GenericPointer& pointer, T2 defaultValue, typename T::AllocatorType& a) { + return pointer.GetWithDefault(root, defaultValue, a); +} + +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const CharType(&source)[N], const typename T::ValueType& defaultValue, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).GetWithDefault(root, defaultValue, a); +} + +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const CharType(&source)[N], const typename T::Ch* defaultValue, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).GetWithDefault(root, defaultValue, a); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename T::ValueType& GetValueByPointerWithDefault(T& root, const CharType(&source)[N], const std::basic_string& defaultValue, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).GetWithDefault(root, defaultValue, a); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename T::ValueType&)) +GetValueByPointerWithDefault(T& root, const CharType(&source)[N], T2 defaultValue, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).GetWithDefault(root, defaultValue, a); +} + +// No allocator parameter + +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const GenericPointer& pointer, const typename DocumentType::ValueType& defaultValue) { + return pointer.GetWithDefault(document, defaultValue); +} + +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const GenericPointer& pointer, const typename DocumentType::Ch* defaultValue) { + return pointer.GetWithDefault(document, defaultValue); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const GenericPointer& pointer, const std::basic_string& defaultValue) { + return pointer.GetWithDefault(document, defaultValue); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename DocumentType::ValueType&)) +GetValueByPointerWithDefault(DocumentType& document, const GenericPointer& pointer, T2 defaultValue) { + return pointer.GetWithDefault(document, defaultValue); +} + +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const CharType(&source)[N], const typename DocumentType::ValueType& defaultValue) { + return GenericPointer(source, N - 1).GetWithDefault(document, defaultValue); +} + +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const CharType(&source)[N], const typename DocumentType::Ch* defaultValue) { + return GenericPointer(source, N - 1).GetWithDefault(document, defaultValue); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename DocumentType::ValueType& GetValueByPointerWithDefault(DocumentType& document, const CharType(&source)[N], const std::basic_string& defaultValue) { + return GenericPointer(source, N - 1).GetWithDefault(document, defaultValue); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename DocumentType::ValueType&)) +GetValueByPointerWithDefault(DocumentType& document, const CharType(&source)[N], T2 defaultValue) { + return GenericPointer(source, N - 1).GetWithDefault(document, defaultValue); +} + +////////////////////////////////////////////////////////////////////////////// + +template +typename T::ValueType& SetValueByPointer(T& root, const GenericPointer& pointer, typename T::ValueType& value, typename T::AllocatorType& a) { + return pointer.Set(root, value, a); +} + +template +typename T::ValueType& SetValueByPointer(T& root, const GenericPointer& pointer, const typename T::ValueType& value, typename T::AllocatorType& a) { + return pointer.Set(root, value, a); +} + +template +typename T::ValueType& SetValueByPointer(T& root, const GenericPointer& pointer, const typename T::Ch* value, typename T::AllocatorType& a) { + return pointer.Set(root, value, a); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename T::ValueType& SetValueByPointer(T& root, const GenericPointer& pointer, const std::basic_string& value, typename T::AllocatorType& a) { + return pointer.Set(root, value, a); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename T::ValueType&)) +SetValueByPointer(T& root, const GenericPointer& pointer, T2 value, typename T::AllocatorType& a) { + return pointer.Set(root, value, a); +} + +template +typename T::ValueType& SetValueByPointer(T& root, const CharType(&source)[N], typename T::ValueType& value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Set(root, value, a); +} + +template +typename T::ValueType& SetValueByPointer(T& root, const CharType(&source)[N], const typename T::ValueType& value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Set(root, value, a); +} + +template +typename T::ValueType& SetValueByPointer(T& root, const CharType(&source)[N], const typename T::Ch* value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Set(root, value, a); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename T::ValueType& SetValueByPointer(T& root, const CharType(&source)[N], const std::basic_string& value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Set(root, value, a); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename T::ValueType&)) +SetValueByPointer(T& root, const CharType(&source)[N], T2 value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Set(root, value, a); +} + +// No allocator parameter + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const GenericPointer& pointer, typename DocumentType::ValueType& value) { + return pointer.Set(document, value); +} + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const GenericPointer& pointer, const typename DocumentType::ValueType& value) { + return pointer.Set(document, value); +} + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const GenericPointer& pointer, const typename DocumentType::Ch* value) { + return pointer.Set(document, value); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const GenericPointer& pointer, const std::basic_string& value) { + return pointer.Set(document, value); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename DocumentType::ValueType&)) +SetValueByPointer(DocumentType& document, const GenericPointer& pointer, T2 value) { + return pointer.Set(document, value); +} + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const CharType(&source)[N], typename DocumentType::ValueType& value) { + return GenericPointer(source, N - 1).Set(document, value); +} + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const CharType(&source)[N], const typename DocumentType::ValueType& value) { + return GenericPointer(source, N - 1).Set(document, value); +} + +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const CharType(&source)[N], const typename DocumentType::Ch* value) { + return GenericPointer(source, N - 1).Set(document, value); +} + +#if RAPIDJSON_HAS_STDSTRING +template +typename DocumentType::ValueType& SetValueByPointer(DocumentType& document, const CharType(&source)[N], const std::basic_string& value) { + return GenericPointer(source, N - 1).Set(document, value); +} +#endif + +template +RAPIDJSON_DISABLEIF_RETURN((internal::OrExpr, internal::IsGenericValue >), (typename DocumentType::ValueType&)) +SetValueByPointer(DocumentType& document, const CharType(&source)[N], T2 value) { + return GenericPointer(source, N - 1).Set(document, value); +} + +////////////////////////////////////////////////////////////////////////////// + +template +typename T::ValueType& SwapValueByPointer(T& root, const GenericPointer& pointer, typename T::ValueType& value, typename T::AllocatorType& a) { + return pointer.Swap(root, value, a); +} + +template +typename T::ValueType& SwapValueByPointer(T& root, const CharType(&source)[N], typename T::ValueType& value, typename T::AllocatorType& a) { + return GenericPointer(source, N - 1).Swap(root, value, a); +} + +template +typename DocumentType::ValueType& SwapValueByPointer(DocumentType& document, const GenericPointer& pointer, typename DocumentType::ValueType& value) { + return pointer.Swap(document, value); +} + +template +typename DocumentType::ValueType& SwapValueByPointer(DocumentType& document, const CharType(&source)[N], typename DocumentType::ValueType& value) { + return GenericPointer(source, N - 1).Swap(document, value); +} + +////////////////////////////////////////////////////////////////////////////// + +template +bool EraseValueByPointer(T& root, const GenericPointer& pointer) { + return pointer.Erase(root); +} + +template +bool EraseValueByPointer(T& root, const CharType(&source)[N]) { + return GenericPointer(source, N - 1).Erase(root); +} + +//@} + +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) || defined(_MSC_VER) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_POINTER_H_ diff --git a/include/rapidjson/prettywriter.h b/include/rapidjson/prettywriter.h new file mode 100644 index 0000000000..fe45df1d10 --- /dev/null +++ b/include/rapidjson/prettywriter.h @@ -0,0 +1,277 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_PRETTYWRITER_H_ +#define RAPIDJSON_PRETTYWRITER_H_ + +#include "writer.h" + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#if defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(c++98-compat) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Combination of PrettyWriter format flags. +/*! \see PrettyWriter::SetFormatOptions + */ +enum PrettyFormatOptions { + kFormatDefault = 0, //!< Default pretty formatting. + kFormatSingleLineArray = 1 //!< Format arrays on a single line. +}; + +//! Writer with indentation and spacing. +/*! + \tparam OutputStream Type of output os. + \tparam SourceEncoding Encoding of source string. + \tparam TargetEncoding Encoding of output stream. + \tparam StackAllocator Type of allocator for allocating memory of stack. +*/ +template, typename TargetEncoding = UTF8<>, typename StackAllocator = CrtAllocator, unsigned writeFlags = kWriteDefaultFlags> +class PrettyWriter : public Writer { +public: + typedef Writer Base; + typedef typename Base::Ch Ch; + + //! Constructor + /*! \param os Output stream. + \param allocator User supplied allocator. If it is null, it will create a private one. + \param levelDepth Initial capacity of stack. + */ + explicit PrettyWriter(OutputStream& os, StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) : + Base(os, allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {} + + + explicit PrettyWriter(StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) : + Base(allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {} + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + PrettyWriter(PrettyWriter&& rhs) : + Base(std::forward(rhs)), indentChar_(rhs.indentChar_), indentCharCount_(rhs.indentCharCount_), formatOptions_(rhs.formatOptions_) {} +#endif + + //! Set custom indentation. + /*! \param indentChar Character for indentation. Must be whitespace character (' ', '\\t', '\\n', '\\r'). + \param indentCharCount Number of indent characters for each indentation level. + \note The default indentation is 4 spaces. + */ + PrettyWriter& SetIndent(Ch indentChar, unsigned indentCharCount) { + RAPIDJSON_ASSERT(indentChar == ' ' || indentChar == '\t' || indentChar == '\n' || indentChar == '\r'); + indentChar_ = indentChar; + indentCharCount_ = indentCharCount; + return *this; + } + + //! Set pretty writer formatting options. + /*! \param options Formatting options. + */ + PrettyWriter& SetFormatOptions(PrettyFormatOptions options) { + formatOptions_ = options; + return *this; + } + + /*! @name Implementation of Handler + \see Handler + */ + //@{ + + bool Null() { PrettyPrefix(kNullType); return Base::EndValue(Base::WriteNull()); } + bool Bool(bool b) { PrettyPrefix(b ? kTrueType : kFalseType); return Base::EndValue(Base::WriteBool(b)); } + bool Int(int i) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt(i)); } + bool Uint(unsigned u) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint(u)); } + bool Int64(int64_t i64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt64(i64)); } + bool Uint64(uint64_t u64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint64(u64)); } + bool Double(double d) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteDouble(d)); } + + bool RawNumber(const Ch* str, SizeType length, bool copy = false) { + RAPIDJSON_ASSERT(str != 0); + (void)copy; + PrettyPrefix(kNumberType); + return Base::EndValue(Base::WriteString(str, length)); + } + + bool String(const Ch* str, SizeType length, bool copy = false) { + RAPIDJSON_ASSERT(str != 0); + (void)copy; + PrettyPrefix(kStringType); + return Base::EndValue(Base::WriteString(str, length)); + } + +#if RAPIDJSON_HAS_STDSTRING + bool String(const std::basic_string& str) { + return String(str.data(), SizeType(str.size())); + } +#endif + + bool StartObject() { + PrettyPrefix(kObjectType); + new (Base::level_stack_.template Push()) typename Base::Level(false); + return Base::WriteStartObject(); + } + + bool Key(const Ch* str, SizeType length, bool copy = false) { return String(str, length, copy); } + +#if RAPIDJSON_HAS_STDSTRING + bool Key(const std::basic_string& str) { + return Key(str.data(), SizeType(str.size())); + } +#endif + + bool EndObject(SizeType memberCount = 0) { + (void)memberCount; + RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level)); // not inside an Object + RAPIDJSON_ASSERT(!Base::level_stack_.template Top()->inArray); // currently inside an Array, not Object + RAPIDJSON_ASSERT(0 == Base::level_stack_.template Top()->valueCount % 2); // Object has a Key without a Value + + bool empty = Base::level_stack_.template Pop(1)->valueCount == 0; + + if (!empty) { + Base::os_->Put('\n'); + WriteIndent(); + } + bool ret = Base::EndValue(Base::WriteEndObject()); + (void)ret; + RAPIDJSON_ASSERT(ret == true); + if (Base::level_stack_.Empty()) // end of json text + Base::Flush(); + return true; + } + + bool StartArray() { + PrettyPrefix(kArrayType); + new (Base::level_stack_.template Push()) typename Base::Level(true); + return Base::WriteStartArray(); + } + + bool EndArray(SizeType memberCount = 0) { + (void)memberCount; + RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level)); + RAPIDJSON_ASSERT(Base::level_stack_.template Top()->inArray); + bool empty = Base::level_stack_.template Pop(1)->valueCount == 0; + + if (!empty && !(formatOptions_ & kFormatSingleLineArray)) { + Base::os_->Put('\n'); + WriteIndent(); + } + bool ret = Base::EndValue(Base::WriteEndArray()); + (void)ret; + RAPIDJSON_ASSERT(ret == true); + if (Base::level_stack_.Empty()) // end of json text + Base::Flush(); + return true; + } + + //@} + + /*! @name Convenience extensions */ + //@{ + + //! Simpler but slower overload. + bool String(const Ch* str) { return String(str, internal::StrLen(str)); } + bool Key(const Ch* str) { return Key(str, internal::StrLen(str)); } + + //@} + + //! Write a raw JSON value. + /*! + For user to write a stringified JSON as a value. + + \param json A well-formed JSON value. It should not contain null character within [0, length - 1] range. + \param length Length of the json. + \param type Type of the root of json. + \note When using PrettyWriter::RawValue(), the result json may not be indented correctly. + */ + bool RawValue(const Ch* json, size_t length, Type type) { + RAPIDJSON_ASSERT(json != 0); + PrettyPrefix(type); + return Base::EndValue(Base::WriteRawValue(json, length)); + } + +protected: + void PrettyPrefix(Type type) { + (void)type; + if (Base::level_stack_.GetSize() != 0) { // this value is not at root + typename Base::Level* level = Base::level_stack_.template Top(); + + if (level->inArray) { + if (level->valueCount > 0) { + Base::os_->Put(','); // add comma if it is not the first element in array + if (formatOptions_ & kFormatSingleLineArray) + Base::os_->Put(' '); + } + + if (!(formatOptions_ & kFormatSingleLineArray)) { + Base::os_->Put('\n'); + WriteIndent(); + } + } + else { // in object + if (level->valueCount > 0) { + if (level->valueCount % 2 == 0) { + Base::os_->Put(','); + Base::os_->Put('\n'); + } + else { + Base::os_->Put(':'); + Base::os_->Put(' '); + } + } + else + Base::os_->Put('\n'); + + if (level->valueCount % 2 == 0) + WriteIndent(); + } + if (!level->inArray && level->valueCount % 2 == 0) + RAPIDJSON_ASSERT(type == kStringType); // if it's in object, then even number should be a name + level->valueCount++; + } + else { + RAPIDJSON_ASSERT(!Base::hasRoot_); // Should only has one and only one root. + Base::hasRoot_ = true; + } + } + + void WriteIndent() { + size_t count = (Base::level_stack_.GetSize() / sizeof(typename Base::Level)) * indentCharCount_; + PutN(*Base::os_, static_cast(indentChar_), count); + } + + Ch indentChar_; + unsigned indentCharCount_; + PrettyFormatOptions formatOptions_; + +private: + // Prohibit copy constructor & assignment operator. + PrettyWriter(const PrettyWriter&); + PrettyWriter& operator=(const PrettyWriter&); +}; + +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_RAPIDJSON_H_ diff --git a/include/rapidjson/rapidjson.h b/include/rapidjson/rapidjson.h new file mode 100644 index 0000000000..247b8e68db --- /dev/null +++ b/include/rapidjson/rapidjson.h @@ -0,0 +1,741 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_RAPIDJSON_H_ +#define RAPIDJSON_RAPIDJSON_H_ + +/*!\file rapidjson.h + \brief common definitions and configuration + + \see RAPIDJSON_CONFIG + */ + +/*! \defgroup RAPIDJSON_CONFIG RapidJSON configuration + \brief Configuration macros for library features + + Some RapidJSON features are configurable to adapt the library to a wide + variety of platforms, environments and usage scenarios. Most of the + features can be configured in terms of overridden or predefined + preprocessor macros at compile-time. + + Some additional customization is available in the \ref RAPIDJSON_ERRORS APIs. + + \note These macros should be given on the compiler command-line + (where applicable) to avoid inconsistent values when compiling + different translation units of a single application. + */ + +#include // malloc(), realloc(), free(), size_t +#include // memset(), memcpy(), memmove(), memcmp() + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_VERSION_STRING +// +// ALWAYS synchronize the following 3 macros with corresponding variables in /CMakeLists.txt. +// + +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +// token stringification +#define RAPIDJSON_STRINGIFY(x) RAPIDJSON_DO_STRINGIFY(x) +#define RAPIDJSON_DO_STRINGIFY(x) #x + +// token concatenation +#define RAPIDJSON_JOIN(X, Y) RAPIDJSON_DO_JOIN(X, Y) +#define RAPIDJSON_DO_JOIN(X, Y) RAPIDJSON_DO_JOIN2(X, Y) +#define RAPIDJSON_DO_JOIN2(X, Y) X##Y +//!@endcond + +/*! \def RAPIDJSON_MAJOR_VERSION + \ingroup RAPIDJSON_CONFIG + \brief Major version of RapidJSON in integer. +*/ +/*! \def RAPIDJSON_MINOR_VERSION + \ingroup RAPIDJSON_CONFIG + \brief Minor version of RapidJSON in integer. +*/ +/*! \def RAPIDJSON_PATCH_VERSION + \ingroup RAPIDJSON_CONFIG + \brief Patch version of RapidJSON in integer. +*/ +/*! \def RAPIDJSON_VERSION_STRING + \ingroup RAPIDJSON_CONFIG + \brief Version of RapidJSON in ".." string format. +*/ +#define RAPIDJSON_MAJOR_VERSION 1 +#define RAPIDJSON_MINOR_VERSION 1 +#define RAPIDJSON_PATCH_VERSION 0 +#define RAPIDJSON_VERSION_STRING \ + RAPIDJSON_STRINGIFY(RAPIDJSON_MAJOR_VERSION.RAPIDJSON_MINOR_VERSION.RAPIDJSON_PATCH_VERSION) + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_NAMESPACE_(BEGIN|END) +/*! \def RAPIDJSON_NAMESPACE + \ingroup RAPIDJSON_CONFIG + \brief provide custom rapidjson namespace + + In order to avoid symbol clashes and/or "One Definition Rule" errors + between multiple inclusions of (different versions of) RapidJSON in + a single binary, users can customize the name of the main RapidJSON + namespace. + + In case of a single nesting level, defining \c RAPIDJSON_NAMESPACE + to a custom name (e.g. \c MyRapidJSON) is sufficient. If multiple + levels are needed, both \ref RAPIDJSON_NAMESPACE_BEGIN and \ref + RAPIDJSON_NAMESPACE_END need to be defined as well: + + \code + // in some .cpp file + #define RAPIDJSON_NAMESPACE my::rapidjson + #define RAPIDJSON_NAMESPACE_BEGIN namespace my { namespace rapidjson { + #define RAPIDJSON_NAMESPACE_END } } + #include "rapidjson/..." + \endcode + + \see rapidjson + */ +/*! \def RAPIDJSON_NAMESPACE_BEGIN + \ingroup RAPIDJSON_CONFIG + \brief provide custom rapidjson namespace (opening expression) + \see RAPIDJSON_NAMESPACE +*/ +/*! \def RAPIDJSON_NAMESPACE_END + \ingroup RAPIDJSON_CONFIG + \brief provide custom rapidjson namespace (closing expression) + \see RAPIDJSON_NAMESPACE +*/ +#ifndef RAPIDJSON_NAMESPACE +#define RAPIDJSON_NAMESPACE rapidjson +#endif +#ifndef RAPIDJSON_NAMESPACE_BEGIN +#define RAPIDJSON_NAMESPACE_BEGIN namespace RAPIDJSON_NAMESPACE { +#endif +#ifndef RAPIDJSON_NAMESPACE_END +#define RAPIDJSON_NAMESPACE_END } +#endif + +/////////////////////////////////////////////////////////////////////////////// +// __cplusplus macro + +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN + +#if defined(_MSC_VER) +#define RAPIDJSON_CPLUSPLUS _MSVC_LANG +#else +#define RAPIDJSON_CPLUSPLUS __cplusplus +#endif + +//!@endcond + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_HAS_STDSTRING + +#ifndef RAPIDJSON_HAS_STDSTRING +#ifdef RAPIDJSON_DOXYGEN_RUNNING +#define RAPIDJSON_HAS_STDSTRING 1 // force generation of documentation +#else +#define RAPIDJSON_HAS_STDSTRING 0 // no std::string support by default +#endif +/*! \def RAPIDJSON_HAS_STDSTRING + \ingroup RAPIDJSON_CONFIG + \brief Enable RapidJSON support for \c std::string + + By defining this preprocessor symbol to \c 1, several convenience functions for using + \ref rapidjson::GenericValue with \c std::string are enabled, especially + for construction and comparison. + + \hideinitializer +*/ +#endif // !defined(RAPIDJSON_HAS_STDSTRING) + +#if RAPIDJSON_HAS_STDSTRING +#include +#endif // RAPIDJSON_HAS_STDSTRING + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_USE_MEMBERSMAP + +/*! \def RAPIDJSON_USE_MEMBERSMAP + \ingroup RAPIDJSON_CONFIG + \brief Enable RapidJSON support for object members handling in a \c std::multimap + + By defining this preprocessor symbol to \c 1, \ref rapidjson::GenericValue object + members are stored in a \c std::multimap for faster lookup and deletion times, a + trade off with a slightly slower insertion time and a small object allocat(or)ed + memory overhead. + + \hideinitializer +*/ +#ifndef RAPIDJSON_USE_MEMBERSMAP +#define RAPIDJSON_USE_MEMBERSMAP 0 // not by default +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_NO_INT64DEFINE + +/*! \def RAPIDJSON_NO_INT64DEFINE + \ingroup RAPIDJSON_CONFIG + \brief Use external 64-bit integer types. + + RapidJSON requires the 64-bit integer types \c int64_t and \c uint64_t types + to be available at global scope. + + If users have their own definition, define RAPIDJSON_NO_INT64DEFINE to + prevent RapidJSON from defining its own types. +*/ +#ifndef RAPIDJSON_NO_INT64DEFINE +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#if defined(_MSC_VER) && (_MSC_VER < 1800) // Visual Studio 2013 +#include "msinttypes/stdint.h" +#include "msinttypes/inttypes.h" +#else +// Other compilers should have this. +#include +#include +#endif +//!@endcond +#ifdef RAPIDJSON_DOXYGEN_RUNNING +#define RAPIDJSON_NO_INT64DEFINE +#endif +#endif // RAPIDJSON_NO_INT64TYPEDEF + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_FORCEINLINE + +#ifndef RAPIDJSON_FORCEINLINE +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#if defined(_MSC_VER) && defined(NDEBUG) +#define RAPIDJSON_FORCEINLINE __forceinline +#elif defined(__GNUC__) && __GNUC__ >= 4 && defined(NDEBUG) +#define RAPIDJSON_FORCEINLINE __attribute__((always_inline)) +#else +#define RAPIDJSON_FORCEINLINE +#endif +//!@endcond +#endif // RAPIDJSON_FORCEINLINE + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_ENDIAN +#define RAPIDJSON_LITTLEENDIAN 0 //!< Little endian machine +#define RAPIDJSON_BIGENDIAN 1 //!< Big endian machine + +//! Endianness of the machine. +/*! + \def RAPIDJSON_ENDIAN + \ingroup RAPIDJSON_CONFIG + + GCC 4.6 provided macro for detecting endianness of the target machine. But other + compilers may not have this. User can define RAPIDJSON_ENDIAN to either + \ref RAPIDJSON_LITTLEENDIAN or \ref RAPIDJSON_BIGENDIAN. + + Default detection implemented with reference to + \li https://gcc.gnu.org/onlinedocs/gcc-4.6.0/cpp/Common-Predefined-Macros.html + \li http://www.boost.org/doc/libs/1_42_0/boost/detail/endian.hpp +*/ +#ifndef RAPIDJSON_ENDIAN +// Detect with GCC 4.6's macro +# ifdef __BYTE_ORDER__ +# if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN +# elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN +# else +# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN. +# endif // __BYTE_ORDER__ +// Detect with GLIBC's endian.h +# elif defined(__GLIBC__) +# include +# if (__BYTE_ORDER == __LITTLE_ENDIAN) +# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN +# elif (__BYTE_ORDER == __BIG_ENDIAN) +# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN +# else +# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN. +# endif // __GLIBC__ +// Detect with _LITTLE_ENDIAN and _BIG_ENDIAN macro +# elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN) +# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN +# elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN) +# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN +// Detect with architecture macros +# elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || defined(__hpux) || defined(__hppa) || defined(_MIPSEB) || defined(_POWER) || defined(__s390__) +# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN +# elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) || defined(__bfin__) +# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN +# elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) +# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN +# elif defined(RAPIDJSON_DOXYGEN_RUNNING) +# define RAPIDJSON_ENDIAN +# else +# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN. +# endif +#endif // RAPIDJSON_ENDIAN + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_64BIT + +//! Whether using 64-bit architecture +#ifndef RAPIDJSON_64BIT +#if defined(__LP64__) || (defined(__x86_64__) && defined(__ILP32__)) || defined(_WIN64) || defined(__EMSCRIPTEN__) +#define RAPIDJSON_64BIT 1 +#else +#define RAPIDJSON_64BIT 0 +#endif +#endif // RAPIDJSON_64BIT + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_ALIGN + +//! Data alignment of the machine. +/*! \ingroup RAPIDJSON_CONFIG + \param x pointer to align + + Some machines require strict data alignment. The default is 8 bytes. + User can customize by defining the RAPIDJSON_ALIGN function macro. +*/ +#ifndef RAPIDJSON_ALIGN +#define RAPIDJSON_ALIGN(x) (((x) + static_cast(7u)) & ~static_cast(7u)) +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_UINT64_C2 + +//! Construct a 64-bit literal by a pair of 32-bit integer. +/*! + 64-bit literal with or without ULL suffix is prone to compiler warnings. + UINT64_C() is C macro which cause compilation problems. + Use this macro to define 64-bit constants by a pair of 32-bit integer. +*/ +#ifndef RAPIDJSON_UINT64_C2 +#define RAPIDJSON_UINT64_C2(high32, low32) ((static_cast(high32) << 32) | static_cast(low32)) +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_48BITPOINTER_OPTIMIZATION + +//! Use only lower 48-bit address for some pointers. +/*! + \ingroup RAPIDJSON_CONFIG + + This optimization uses the fact that current X86-64 architecture only implement lower 48-bit virtual address. + The higher 16-bit can be used for storing other data. + \c GenericValue uses this optimization to reduce its size form 24 bytes to 16 bytes in 64-bit architecture. +*/ +#ifndef RAPIDJSON_48BITPOINTER_OPTIMIZATION +#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) +#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 1 +#else +#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 0 +#endif +#endif // RAPIDJSON_48BITPOINTER_OPTIMIZATION + +#if RAPIDJSON_48BITPOINTER_OPTIMIZATION == 1 +#if RAPIDJSON_64BIT != 1 +#error RAPIDJSON_48BITPOINTER_OPTIMIZATION can only be set to 1 when RAPIDJSON_64BIT=1 +#endif +#define RAPIDJSON_SETPOINTER(type, p, x) (p = reinterpret_cast((reinterpret_cast(p) & static_cast(RAPIDJSON_UINT64_C2(0xFFFF0000, 0x00000000))) | reinterpret_cast(reinterpret_cast(x)))) +#define RAPIDJSON_GETPOINTER(type, p) (reinterpret_cast(reinterpret_cast(p) & static_cast(RAPIDJSON_UINT64_C2(0x0000FFFF, 0xFFFFFFFF)))) +#else +#define RAPIDJSON_SETPOINTER(type, p, x) (p = (x)) +#define RAPIDJSON_GETPOINTER(type, p) (p) +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_SSE2/RAPIDJSON_SSE42/RAPIDJSON_NEON/RAPIDJSON_SIMD + +/*! \def RAPIDJSON_SIMD + \ingroup RAPIDJSON_CONFIG + \brief Enable SSE2/SSE4.2/Neon optimization. + + RapidJSON supports optimized implementations for some parsing operations + based on the SSE2, SSE4.2 or NEon SIMD extensions on modern Intel + or ARM compatible processors. + + To enable these optimizations, three different symbols can be defined; + \code + // Enable SSE2 optimization. + #define RAPIDJSON_SSE2 + + // Enable SSE4.2 optimization. + #define RAPIDJSON_SSE42 + \endcode + + // Enable ARM Neon optimization. + #define RAPIDJSON_NEON + \endcode + + \c RAPIDJSON_SSE42 takes precedence over SSE2, if both are defined. + + If any of these symbols is defined, RapidJSON defines the macro + \c RAPIDJSON_SIMD to indicate the availability of the optimized code. +*/ +#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) \ + || defined(RAPIDJSON_NEON) || defined(RAPIDJSON_DOXYGEN_RUNNING) +#define RAPIDJSON_SIMD +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_NO_SIZETYPEDEFINE + +#ifndef RAPIDJSON_NO_SIZETYPEDEFINE +/*! \def RAPIDJSON_NO_SIZETYPEDEFINE + \ingroup RAPIDJSON_CONFIG + \brief User-provided \c SizeType definition. + + In order to avoid using 32-bit size types for indexing strings and arrays, + define this preprocessor symbol and provide the type rapidjson::SizeType + before including RapidJSON: + \code + #define RAPIDJSON_NO_SIZETYPEDEFINE + namespace rapidjson { typedef ::std::size_t SizeType; } + #include "rapidjson/..." + \endcode + + \see rapidjson::SizeType +*/ +#ifdef RAPIDJSON_DOXYGEN_RUNNING +#define RAPIDJSON_NO_SIZETYPEDEFINE +#endif +RAPIDJSON_NAMESPACE_BEGIN +//! Size type (for string lengths, array sizes, etc.) +/*! RapidJSON uses 32-bit array/string indices even on 64-bit platforms, + instead of using \c size_t. Users may override the SizeType by defining + \ref RAPIDJSON_NO_SIZETYPEDEFINE. +*/ +typedef unsigned SizeType; +RAPIDJSON_NAMESPACE_END +#endif + +// always import std::size_t to rapidjson namespace +RAPIDJSON_NAMESPACE_BEGIN +using std::size_t; +RAPIDJSON_NAMESPACE_END + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_ASSERT + +//! Assertion. +/*! \ingroup RAPIDJSON_CONFIG + By default, rapidjson uses C \c assert() for internal assertions. + User can override it by defining RAPIDJSON_ASSERT(x) macro. + + \note Parsing errors are handled and can be customized by the + \ref RAPIDJSON_ERRORS APIs. +*/ +#ifndef RAPIDJSON_ASSERT +#include +#define RAPIDJSON_ASSERT(x) assert(x) +#endif // RAPIDJSON_ASSERT + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_STATIC_ASSERT + +// Prefer C++11 static_assert, if available +#ifndef RAPIDJSON_STATIC_ASSERT +#if RAPIDJSON_CPLUSPLUS >= 201103L || ( defined(_MSC_VER) && _MSC_VER >= 1800 ) +#define RAPIDJSON_STATIC_ASSERT(x) \ + static_assert(x, RAPIDJSON_STRINGIFY(x)) +#endif // C++11 +#endif // RAPIDJSON_STATIC_ASSERT + +// Adopt C++03 implementation from boost +#ifndef RAPIDJSON_STATIC_ASSERT +#ifndef __clang__ +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#endif +RAPIDJSON_NAMESPACE_BEGIN +template struct STATIC_ASSERTION_FAILURE; +template <> struct STATIC_ASSERTION_FAILURE { enum { value = 1 }; }; +template struct StaticAssertTest {}; +RAPIDJSON_NAMESPACE_END + +#if defined(__GNUC__) || defined(__clang__) +#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE __attribute__((unused)) +#else +#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE +#endif +#ifndef __clang__ +//!@endcond +#endif + +/*! \def RAPIDJSON_STATIC_ASSERT + \brief (Internal) macro to check for conditions at compile-time + \param x compile-time condition + \hideinitializer + */ +#define RAPIDJSON_STATIC_ASSERT(x) \ + typedef ::RAPIDJSON_NAMESPACE::StaticAssertTest< \ + sizeof(::RAPIDJSON_NAMESPACE::STATIC_ASSERTION_FAILURE)> \ + RAPIDJSON_JOIN(StaticAssertTypedef, __LINE__) RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE +#endif // RAPIDJSON_STATIC_ASSERT + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_LIKELY, RAPIDJSON_UNLIKELY + +//! Compiler branching hint for expression with high probability to be true. +/*! + \ingroup RAPIDJSON_CONFIG + \param x Boolean expression likely to be true. +*/ +#ifndef RAPIDJSON_LIKELY +#if defined(__GNUC__) || defined(__clang__) +#define RAPIDJSON_LIKELY(x) __builtin_expect(!!(x), 1) +#else +#define RAPIDJSON_LIKELY(x) (x) +#endif +#endif + +//! Compiler branching hint for expression with low probability to be true. +/*! + \ingroup RAPIDJSON_CONFIG + \param x Boolean expression unlikely to be true. +*/ +#ifndef RAPIDJSON_UNLIKELY +#if defined(__GNUC__) || defined(__clang__) +#define RAPIDJSON_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else +#define RAPIDJSON_UNLIKELY(x) (x) +#endif +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Helpers + +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN + +#define RAPIDJSON_MULTILINEMACRO_BEGIN do { +#define RAPIDJSON_MULTILINEMACRO_END \ +} while((void)0, 0) + +// adopted from Boost +#define RAPIDJSON_VERSION_CODE(x,y,z) \ + (((x)*100000) + ((y)*100) + (z)) + +#if defined(__has_builtin) +#define RAPIDJSON_HAS_BUILTIN(x) __has_builtin(x) +#else +#define RAPIDJSON_HAS_BUILTIN(x) 0 +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_DIAG_PUSH/POP, RAPIDJSON_DIAG_OFF + +#if defined(__GNUC__) +#define RAPIDJSON_GNUC \ + RAPIDJSON_VERSION_CODE(__GNUC__,__GNUC_MINOR__,__GNUC_PATCHLEVEL__) +#endif + +#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,2,0)) + +#define RAPIDJSON_PRAGMA(x) _Pragma(RAPIDJSON_STRINGIFY(x)) +#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(GCC diagnostic x) +#define RAPIDJSON_DIAG_OFF(x) \ + RAPIDJSON_DIAG_PRAGMA(ignored RAPIDJSON_STRINGIFY(RAPIDJSON_JOIN(-W,x))) + +// push/pop support in Clang and GCC>=4.6 +#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) +#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push) +#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop) +#else // GCC >= 4.2, < 4.6 +#define RAPIDJSON_DIAG_PUSH /* ignored */ +#define RAPIDJSON_DIAG_POP /* ignored */ +#endif + +#elif defined(_MSC_VER) + +// pragma (MSVC specific) +#define RAPIDJSON_PRAGMA(x) __pragma(x) +#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(warning(x)) + +#define RAPIDJSON_DIAG_OFF(x) RAPIDJSON_DIAG_PRAGMA(disable: x) +#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push) +#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop) + +#else + +#define RAPIDJSON_DIAG_OFF(x) /* ignored */ +#define RAPIDJSON_DIAG_PUSH /* ignored */ +#define RAPIDJSON_DIAG_POP /* ignored */ + +#endif // RAPIDJSON_DIAG_* + +/////////////////////////////////////////////////////////////////////////////// +// C++11 features + +#ifndef RAPIDJSON_HAS_CXX11 +#define RAPIDJSON_HAS_CXX11 (RAPIDJSON_CPLUSPLUS >= 201103L) +#endif + +#ifndef RAPIDJSON_HAS_CXX11_RVALUE_REFS +#if RAPIDJSON_HAS_CXX11 +#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1 +#elif defined(__clang__) +#if __has_feature(cxx_rvalue_references) && \ + (defined(_MSC_VER) || defined(_LIBCPP_VERSION) || defined(__GLIBCXX__) && __GLIBCXX__ >= 20080306) +#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1 +#else +#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 0 +#endif +#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,3,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \ + (defined(_MSC_VER) && _MSC_VER >= 1600) || \ + (defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__)) + +#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1 +#else +#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 0 +#endif +#endif // RAPIDJSON_HAS_CXX11_RVALUE_REFS + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS +#include // std::move +#endif + +#ifndef RAPIDJSON_HAS_CXX11_NOEXCEPT +#if RAPIDJSON_HAS_CXX11 +#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1 +#elif defined(__clang__) +#define RAPIDJSON_HAS_CXX11_NOEXCEPT __has_feature(cxx_noexcept) +#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \ + (defined(_MSC_VER) && _MSC_VER >= 1900) || \ + (defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__)) +#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1 +#else +#define RAPIDJSON_HAS_CXX11_NOEXCEPT 0 +#endif +#endif +#ifndef RAPIDJSON_NOEXCEPT +#if RAPIDJSON_HAS_CXX11_NOEXCEPT +#define RAPIDJSON_NOEXCEPT noexcept +#else +#define RAPIDJSON_NOEXCEPT throw() +#endif // RAPIDJSON_HAS_CXX11_NOEXCEPT +#endif + +// no automatic detection, yet +#ifndef RAPIDJSON_HAS_CXX11_TYPETRAITS +#if (defined(_MSC_VER) && _MSC_VER >= 1700) +#define RAPIDJSON_HAS_CXX11_TYPETRAITS 1 +#else +#define RAPIDJSON_HAS_CXX11_TYPETRAITS 0 +#endif +#endif + +#ifndef RAPIDJSON_HAS_CXX11_RANGE_FOR +#if defined(__clang__) +#define RAPIDJSON_HAS_CXX11_RANGE_FOR __has_feature(cxx_range_for) +#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \ + (defined(_MSC_VER) && _MSC_VER >= 1700) || \ + (defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__)) +#define RAPIDJSON_HAS_CXX11_RANGE_FOR 1 +#else +#define RAPIDJSON_HAS_CXX11_RANGE_FOR 0 +#endif +#endif // RAPIDJSON_HAS_CXX11_RANGE_FOR + +/////////////////////////////////////////////////////////////////////////////// +// C++17 features + +#ifndef RAPIDJSON_HAS_CXX17 +#define RAPIDJSON_HAS_CXX17 (RAPIDJSON_CPLUSPLUS >= 201703L) +#endif + +#if RAPIDJSON_HAS_CXX17 +# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[fallthrough]] +#elif defined(__has_cpp_attribute) +# if __has_cpp_attribute(clang::fallthrough) +# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[clang::fallthrough]] +# elif __has_cpp_attribute(fallthrough) +# define RAPIDJSON_DELIBERATE_FALLTHROUGH __attribute__((fallthrough)) +# else +# define RAPIDJSON_DELIBERATE_FALLTHROUGH +# endif +#else +# define RAPIDJSON_DELIBERATE_FALLTHROUGH +#endif + +//!@endcond + +//! Assertion (in non-throwing contexts). + /*! \ingroup RAPIDJSON_CONFIG + Some functions provide a \c noexcept guarantee, if the compiler supports it. + In these cases, the \ref RAPIDJSON_ASSERT macro cannot be overridden to + throw an exception. This macro adds a separate customization point for + such cases. + + Defaults to C \c assert() (as \ref RAPIDJSON_ASSERT), if \c noexcept is + supported, and to \ref RAPIDJSON_ASSERT otherwise. + */ + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_NOEXCEPT_ASSERT + +#ifndef RAPIDJSON_NOEXCEPT_ASSERT +#ifdef RAPIDJSON_ASSERT_THROWS +#include +#define RAPIDJSON_NOEXCEPT_ASSERT(x) assert(x) +#else +#define RAPIDJSON_NOEXCEPT_ASSERT(x) RAPIDJSON_ASSERT(x) +#endif // RAPIDJSON_ASSERT_THROWS +#endif // RAPIDJSON_NOEXCEPT_ASSERT + +/////////////////////////////////////////////////////////////////////////////// +// malloc/realloc/free + +#ifndef RAPIDJSON_MALLOC +///! customization point for global \c malloc +#define RAPIDJSON_MALLOC(size) std::malloc(size) +#endif +#ifndef RAPIDJSON_REALLOC +///! customization point for global \c realloc +#define RAPIDJSON_REALLOC(ptr, new_size) std::realloc(ptr, new_size) +#endif +#ifndef RAPIDJSON_FREE +///! customization point for global \c free +#define RAPIDJSON_FREE(ptr) std::free(ptr) +#endif + +/////////////////////////////////////////////////////////////////////////////// +// new/delete + +#ifndef RAPIDJSON_NEW +///! customization point for global \c new +#define RAPIDJSON_NEW(TypeName) new TypeName +#endif +#ifndef RAPIDJSON_DELETE +///! customization point for global \c delete +#define RAPIDJSON_DELETE(x) delete x +#endif + +/////////////////////////////////////////////////////////////////////////////// +// Type + +/*! \namespace rapidjson + \brief main RapidJSON namespace + \see RAPIDJSON_NAMESPACE +*/ +RAPIDJSON_NAMESPACE_BEGIN + +//! Type of JSON value +enum Type { + kNullType = 0, //!< null + kFalseType = 1, //!< false + kTrueType = 2, //!< true + kObjectType = 3, //!< object + kArrayType = 4, //!< array + kStringType = 5, //!< string + kNumberType = 6 //!< number +}; + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_RAPIDJSON_H_ diff --git a/include/rapidjson/reader.h b/include/rapidjson/reader.h new file mode 100644 index 0000000000..f7ef610244 --- /dev/null +++ b/include/rapidjson/reader.h @@ -0,0 +1,2246 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_READER_H_ +#define RAPIDJSON_READER_H_ + +/*! \file reader.h */ + +#include "allocators.h" +#include "stream.h" +#include "encodedstream.h" +#include "internal/clzll.h" +#include "internal/meta.h" +#include "internal/stack.h" +#include "internal/strtod.h" +#include + +#if defined(RAPIDJSON_SIMD) && defined(_MSC_VER) +#include +#pragma intrinsic(_BitScanForward) +#endif +#ifdef RAPIDJSON_SSE42 +#include +#elif defined(RAPIDJSON_SSE2) +#include +#elif defined(RAPIDJSON_NEON) +#include +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(old-style-cast) +RAPIDJSON_DIAG_OFF(padded) +RAPIDJSON_DIAG_OFF(switch-enum) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4127) // conditional expression is constant +RAPIDJSON_DIAG_OFF(4702) // unreachable code +#endif + +#ifdef __GNUC__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(effc++) +#endif + +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#define RAPIDJSON_NOTHING /* deliberately empty */ +#ifndef RAPIDJSON_PARSE_ERROR_EARLY_RETURN +#define RAPIDJSON_PARSE_ERROR_EARLY_RETURN(value) \ + RAPIDJSON_MULTILINEMACRO_BEGIN \ + if (RAPIDJSON_UNLIKELY(HasParseError())) { return value; } \ + RAPIDJSON_MULTILINEMACRO_END +#endif +#define RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID \ + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(RAPIDJSON_NOTHING) +//!@endcond + +/*! \def RAPIDJSON_PARSE_ERROR_NORETURN + \ingroup RAPIDJSON_ERRORS + \brief Macro to indicate a parse error. + \param parseErrorCode \ref rapidjson::ParseErrorCode of the error + \param offset position of the error in JSON input (\c size_t) + + This macros can be used as a customization point for the internal + error handling mechanism of RapidJSON. + + A common usage model is to throw an exception instead of requiring the + caller to explicitly check the \ref rapidjson::GenericReader::Parse's + return value: + + \code + #define RAPIDJSON_PARSE_ERROR_NORETURN(parseErrorCode,offset) \ + throw ParseException(parseErrorCode, #parseErrorCode, offset) + + #include // std::runtime_error + #include "rapidjson/error/error.h" // rapidjson::ParseResult + + struct ParseException : std::runtime_error, rapidjson::ParseResult { + ParseException(rapidjson::ParseErrorCode code, const char* msg, size_t offset) + : std::runtime_error(msg), ParseResult(code, offset) {} + }; + + #include "rapidjson/reader.h" + \endcode + + \see RAPIDJSON_PARSE_ERROR, rapidjson::GenericReader::Parse + */ +#ifndef RAPIDJSON_PARSE_ERROR_NORETURN +#define RAPIDJSON_PARSE_ERROR_NORETURN(parseErrorCode, offset) \ + RAPIDJSON_MULTILINEMACRO_BEGIN \ + RAPIDJSON_ASSERT(!HasParseError()); /* Error can only be assigned once */ \ + SetParseError(parseErrorCode, offset); \ + RAPIDJSON_MULTILINEMACRO_END +#endif + +/*! \def RAPIDJSON_PARSE_ERROR + \ingroup RAPIDJSON_ERRORS + \brief (Internal) macro to indicate and handle a parse error. + \param parseErrorCode \ref rapidjson::ParseErrorCode of the error + \param offset position of the error in JSON input (\c size_t) + + Invokes RAPIDJSON_PARSE_ERROR_NORETURN and stops the parsing. + + \see RAPIDJSON_PARSE_ERROR_NORETURN + \hideinitializer + */ +#ifndef RAPIDJSON_PARSE_ERROR +#define RAPIDJSON_PARSE_ERROR(parseErrorCode, offset) \ + RAPIDJSON_MULTILINEMACRO_BEGIN \ + RAPIDJSON_PARSE_ERROR_NORETURN(parseErrorCode, offset); \ + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; \ + RAPIDJSON_MULTILINEMACRO_END +#endif + +#include "error/error.h" // ParseErrorCode, ParseResult + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// ParseFlag + +/*! \def RAPIDJSON_PARSE_DEFAULT_FLAGS + \ingroup RAPIDJSON_CONFIG + \brief User-defined kParseDefaultFlags definition. + + User can define this as any \c ParseFlag combinations. +*/ +#ifndef RAPIDJSON_PARSE_DEFAULT_FLAGS +#define RAPIDJSON_PARSE_DEFAULT_FLAGS kParseNoFlags +#endif + +//! Combination of parseFlags +/*! \see Reader::Parse, Document::Parse, Document::ParseInsitu, Document::ParseStream + */ +enum ParseFlag { + kParseNoFlags = 0, //!< No flags are set. + kParseInsituFlag = 1, //!< In-situ(destructive) parsing. + kParseValidateEncodingFlag = 2, //!< Validate encoding of JSON strings. + kParseIterativeFlag = 4, //!< Iterative(constant complexity in terms of function call stack size) parsing. + kParseStopWhenDoneFlag = 8, //!< After parsing a complete JSON root from stream, stop further processing the rest of stream. When this flag is used, parser will not generate kParseErrorDocumentRootNotSingular error. + kParseFullPrecisionFlag = 16, //!< Parse number in full precision (but slower). + kParseCommentsFlag = 32, //!< Allow one-line (//) and multi-line (/**/) comments. + kParseNumbersAsStringsFlag = 64, //!< Parse all numbers (ints/doubles) as strings. + kParseTrailingCommasFlag = 128, //!< Allow trailing commas at the end of objects and arrays. + kParseNanAndInfFlag = 256, //!< Allow parsing NaN, Inf, Infinity, -Inf and -Infinity as doubles. + kParseEscapedApostropheFlag = 512, //!< Allow escaped apostrophe in strings. + kParseDefaultFlags = RAPIDJSON_PARSE_DEFAULT_FLAGS //!< Default parse flags. Can be customized by defining RAPIDJSON_PARSE_DEFAULT_FLAGS +}; + +/////////////////////////////////////////////////////////////////////////////// +// Handler + +/*! \class rapidjson::Handler + \brief Concept for receiving events from GenericReader upon parsing. + The functions return true if no error occurs. If they return false, + the event publisher should terminate the process. +\code +concept Handler { + typename Ch; + + bool Null(); + bool Bool(bool b); + bool Int(int i); + bool Uint(unsigned i); + bool Int64(int64_t i); + bool Uint64(uint64_t i); + bool Double(double d); + /// enabled via kParseNumbersAsStringsFlag, string is not null-terminated (use length) + bool RawNumber(const Ch* str, SizeType length, bool copy); + bool String(const Ch* str, SizeType length, bool copy); + bool StartObject(); + bool Key(const Ch* str, SizeType length, bool copy); + bool EndObject(SizeType memberCount); + bool StartArray(); + bool EndArray(SizeType elementCount); +}; +\endcode +*/ +/////////////////////////////////////////////////////////////////////////////// +// BaseReaderHandler + +//! Default implementation of Handler. +/*! This can be used as base class of any reader handler. + \note implements Handler concept +*/ +template, typename Derived = void> +struct BaseReaderHandler { + typedef typename Encoding::Ch Ch; + + typedef typename internal::SelectIf, BaseReaderHandler, Derived>::Type Override; + + bool Default() { return true; } + bool Null() { return static_cast(*this).Default(); } + bool Bool(bool) { return static_cast(*this).Default(); } + bool Int(int) { return static_cast(*this).Default(); } + bool Uint(unsigned) { return static_cast(*this).Default(); } + bool Int64(int64_t) { return static_cast(*this).Default(); } + bool Uint64(uint64_t) { return static_cast(*this).Default(); } + bool Double(double) { return static_cast(*this).Default(); } + /// enabled via kParseNumbersAsStringsFlag, string is not null-terminated (use length) + bool RawNumber(const Ch* str, SizeType len, bool copy) { return static_cast(*this).String(str, len, copy); } + bool String(const Ch*, SizeType, bool) { return static_cast(*this).Default(); } + bool StartObject() { return static_cast(*this).Default(); } + bool Key(const Ch* str, SizeType len, bool copy) { return static_cast(*this).String(str, len, copy); } + bool EndObject(SizeType) { return static_cast(*this).Default(); } + bool StartArray() { return static_cast(*this).Default(); } + bool EndArray(SizeType) { return static_cast(*this).Default(); } +}; + +/////////////////////////////////////////////////////////////////////////////// +// StreamLocalCopy + +namespace internal { + +template::copyOptimization> +class StreamLocalCopy; + +//! Do copy optimization. +template +class StreamLocalCopy { +public: + StreamLocalCopy(Stream& original) : s(original), original_(original) {} + ~StreamLocalCopy() { original_ = s; } + + Stream s; + +private: + StreamLocalCopy& operator=(const StreamLocalCopy&) /* = delete */; + + Stream& original_; +}; + +//! Keep reference. +template +class StreamLocalCopy { +public: + StreamLocalCopy(Stream& original) : s(original) {} + + Stream& s; + +private: + StreamLocalCopy& operator=(const StreamLocalCopy&) /* = delete */; +}; + +} // namespace internal + +/////////////////////////////////////////////////////////////////////////////// +// SkipWhitespace + +//! Skip the JSON white spaces in a stream. +/*! \param is A input stream for skipping white spaces. + \note This function has SSE2/SSE4.2 specialization. +*/ +template +void SkipWhitespace(InputStream& is) { + internal::StreamLocalCopy copy(is); + InputStream& s(copy.s); + + typename InputStream::Ch c; + while ((c = s.Peek()) == ' ' || c == '\n' || c == '\r' || c == '\t') + s.Take(); +} + +inline const char* SkipWhitespace(const char* p, const char* end) { + while (p != end && (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t')) + ++p; + return p; +} + +#ifdef RAPIDJSON_SSE42 +//! Skip whitespace with SSE 4.2 pcmpistrm instruction, testing 16 8-byte characters at once. +inline const char *SkipWhitespace_SIMD(const char* p) { + // Fast return for single non-whitespace + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + // 16-byte align to the next boundary + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + // The rest of string using SIMD + static const char whitespace[16] = " \n\r\t"; + const __m128i w = _mm_loadu_si128(reinterpret_cast(&whitespace[0])); + + for (;; p += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + const int r = _mm_cmpistri(w, s, _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT | _SIDD_NEGATIVE_POLARITY); + if (r != 16) // some of characters is non-whitespace + return p + r; + } +} + +inline const char *SkipWhitespace_SIMD(const char* p, const char* end) { + // Fast return for single non-whitespace + if (p != end && (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t')) + ++p; + else + return p; + + // The middle of string using SIMD + static const char whitespace[16] = " \n\r\t"; + const __m128i w = _mm_loadu_si128(reinterpret_cast(&whitespace[0])); + + for (; p <= end - 16; p += 16) { + const __m128i s = _mm_loadu_si128(reinterpret_cast(p)); + const int r = _mm_cmpistri(w, s, _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT | _SIDD_NEGATIVE_POLARITY); + if (r != 16) // some of characters is non-whitespace + return p + r; + } + + return SkipWhitespace(p, end); +} + +#elif defined(RAPIDJSON_SSE2) + +//! Skip whitespace with SSE2 instructions, testing 16 8-byte characters at once. +inline const char *SkipWhitespace_SIMD(const char* p) { + // Fast return for single non-whitespace + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + // 16-byte align to the next boundary + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + // The rest of string + #define C16(c) { c, c, c, c, c, c, c, c, c, c, c, c, c, c, c, c } + static const char whitespaces[4][16] = { C16(' '), C16('\n'), C16('\r'), C16('\t') }; + #undef C16 + + const __m128i w0 = _mm_loadu_si128(reinterpret_cast(&whitespaces[0][0])); + const __m128i w1 = _mm_loadu_si128(reinterpret_cast(&whitespaces[1][0])); + const __m128i w2 = _mm_loadu_si128(reinterpret_cast(&whitespaces[2][0])); + const __m128i w3 = _mm_loadu_si128(reinterpret_cast(&whitespaces[3][0])); + + for (;; p += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + __m128i x = _mm_cmpeq_epi8(s, w0); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w1)); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w2)); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w3)); + unsigned short r = static_cast(~_mm_movemask_epi8(x)); + if (r != 0) { // some of characters may be non-whitespace +#ifdef _MSC_VER // Find the index of first non-whitespace + unsigned long offset; + _BitScanForward(&offset, r); + return p + offset; +#else + return p + __builtin_ffs(r) - 1; +#endif + } + } +} + +inline const char *SkipWhitespace_SIMD(const char* p, const char* end) { + // Fast return for single non-whitespace + if (p != end && (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t')) + ++p; + else + return p; + + // The rest of string + #define C16(c) { c, c, c, c, c, c, c, c, c, c, c, c, c, c, c, c } + static const char whitespaces[4][16] = { C16(' '), C16('\n'), C16('\r'), C16('\t') }; + #undef C16 + + const __m128i w0 = _mm_loadu_si128(reinterpret_cast(&whitespaces[0][0])); + const __m128i w1 = _mm_loadu_si128(reinterpret_cast(&whitespaces[1][0])); + const __m128i w2 = _mm_loadu_si128(reinterpret_cast(&whitespaces[2][0])); + const __m128i w3 = _mm_loadu_si128(reinterpret_cast(&whitespaces[3][0])); + + for (; p <= end - 16; p += 16) { + const __m128i s = _mm_loadu_si128(reinterpret_cast(p)); + __m128i x = _mm_cmpeq_epi8(s, w0); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w1)); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w2)); + x = _mm_or_si128(x, _mm_cmpeq_epi8(s, w3)); + unsigned short r = static_cast(~_mm_movemask_epi8(x)); + if (r != 0) { // some of characters may be non-whitespace +#ifdef _MSC_VER // Find the index of first non-whitespace + unsigned long offset; + _BitScanForward(&offset, r); + return p + offset; +#else + return p + __builtin_ffs(r) - 1; +#endif + } + } + + return SkipWhitespace(p, end); +} + +#elif defined(RAPIDJSON_NEON) + +//! Skip whitespace with ARM Neon instructions, testing 16 8-byte characters at once. +inline const char *SkipWhitespace_SIMD(const char* p) { + // Fast return for single non-whitespace + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + // 16-byte align to the next boundary + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t') + ++p; + else + return p; + + const uint8x16_t w0 = vmovq_n_u8(' '); + const uint8x16_t w1 = vmovq_n_u8('\n'); + const uint8x16_t w2 = vmovq_n_u8('\r'); + const uint8x16_t w3 = vmovq_n_u8('\t'); + + for (;; p += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, w0); + x = vorrq_u8(x, vceqq_u8(s, w1)); + x = vorrq_u8(x, vceqq_u8(s, w2)); + x = vorrq_u8(x, vceqq_u8(s, w3)); + + x = vmvnq_u8(x); // Negate + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + return p + 8 + (lz >> 3); + } + } else { + uint32_t lz = internal::clzll(low); + return p + (lz >> 3); + } + } +} + +inline const char *SkipWhitespace_SIMD(const char* p, const char* end) { + // Fast return for single non-whitespace + if (p != end && (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t')) + ++p; + else + return p; + + const uint8x16_t w0 = vmovq_n_u8(' '); + const uint8x16_t w1 = vmovq_n_u8('\n'); + const uint8x16_t w2 = vmovq_n_u8('\r'); + const uint8x16_t w3 = vmovq_n_u8('\t'); + + for (; p <= end - 16; p += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, w0); + x = vorrq_u8(x, vceqq_u8(s, w1)); + x = vorrq_u8(x, vceqq_u8(s, w2)); + x = vorrq_u8(x, vceqq_u8(s, w3)); + + x = vmvnq_u8(x); // Negate + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + return p + 8 + (lz >> 3); + } + } else { + uint32_t lz = internal::clzll(low); + return p + (lz >> 3); + } + } + + return SkipWhitespace(p, end); +} + +#endif // RAPIDJSON_NEON + +#ifdef RAPIDJSON_SIMD +//! Template function specialization for InsituStringStream +template<> inline void SkipWhitespace(InsituStringStream& is) { + is.src_ = const_cast(SkipWhitespace_SIMD(is.src_)); +} + +//! Template function specialization for StringStream +template<> inline void SkipWhitespace(StringStream& is) { + is.src_ = SkipWhitespace_SIMD(is.src_); +} + +template<> inline void SkipWhitespace(EncodedInputStream, MemoryStream>& is) { + is.is_.src_ = SkipWhitespace_SIMD(is.is_.src_, is.is_.end_); +} +#endif // RAPIDJSON_SIMD + +/////////////////////////////////////////////////////////////////////////////// +// GenericReader + +//! SAX-style JSON parser. Use \ref Reader for UTF8 encoding and default allocator. +/*! GenericReader parses JSON text from a stream, and send events synchronously to an + object implementing Handler concept. + + It needs to allocate a stack for storing a single decoded string during + non-destructive parsing. + + For in-situ parsing, the decoded string is directly written to the source + text string, no temporary buffer is required. + + A GenericReader object can be reused for parsing multiple JSON text. + + \tparam SourceEncoding Encoding of the input stream. + \tparam TargetEncoding Encoding of the parse output. + \tparam StackAllocator Allocator type for stack. +*/ +template +class GenericReader { +public: + typedef typename SourceEncoding::Ch Ch; //!< SourceEncoding character type + + //! Constructor. + /*! \param stackAllocator Optional allocator for allocating stack memory. (Only use for non-destructive parsing) + \param stackCapacity stack capacity in bytes for storing a single decoded string. (Only use for non-destructive parsing) + */ + GenericReader(StackAllocator* stackAllocator = 0, size_t stackCapacity = kDefaultStackCapacity) : + stack_(stackAllocator, stackCapacity), parseResult_(), state_(IterativeParsingStartState) {} + + //! Parse JSON text. + /*! \tparam parseFlags Combination of \ref ParseFlag. + \tparam InputStream Type of input stream, implementing Stream concept. + \tparam Handler Type of handler, implementing Handler concept. + \param is Input stream to be parsed. + \param handler The handler to receive events. + \return Whether the parsing is successful. + */ + template + ParseResult Parse(InputStream& is, Handler& handler) { + if (parseFlags & kParseIterativeFlag) + return IterativeParse(is, handler); + + parseResult_.Clear(); + + ClearStackOnExit scope(*this); + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + + if (RAPIDJSON_UNLIKELY(is.Peek() == '\0')) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorDocumentEmpty, is.Tell()); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + } + else { + ParseValue(is, handler); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + + if (!(parseFlags & kParseStopWhenDoneFlag)) { + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + + if (RAPIDJSON_UNLIKELY(is.Peek() != '\0')) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorDocumentRootNotSingular, is.Tell()); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + } + } + } + + return parseResult_; + } + + //! Parse JSON text (with \ref kParseDefaultFlags) + /*! \tparam InputStream Type of input stream, implementing Stream concept + \tparam Handler Type of handler, implementing Handler concept. + \param is Input stream to be parsed. + \param handler The handler to receive events. + \return Whether the parsing is successful. + */ + template + ParseResult Parse(InputStream& is, Handler& handler) { + return Parse(is, handler); + } + + //! Initialize JSON text token-by-token parsing + /*! + */ + void IterativeParseInit() { + parseResult_.Clear(); + state_ = IterativeParsingStartState; + } + + //! Parse one token from JSON text + /*! \tparam InputStream Type of input stream, implementing Stream concept + \tparam Handler Type of handler, implementing Handler concept. + \param is Input stream to be parsed. + \param handler The handler to receive events. + \return Whether the parsing is successful. + */ + template + bool IterativeParseNext(InputStream& is, Handler& handler) { + while (RAPIDJSON_LIKELY(is.Peek() != '\0')) { + SkipWhitespaceAndComments(is); + + Token t = Tokenize(is.Peek()); + IterativeParsingState n = Predict(state_, t); + IterativeParsingState d = Transit(state_, t, n, is, handler); + + // If we've finished or hit an error... + if (RAPIDJSON_UNLIKELY(IsIterativeParsingCompleteState(d))) { + // Report errors. + if (d == IterativeParsingErrorState) { + HandleError(state_, is); + return false; + } + + // Transition to the finish state. + RAPIDJSON_ASSERT(d == IterativeParsingFinishState); + state_ = d; + + // If StopWhenDone is not set... + if (!(parseFlags & kParseStopWhenDoneFlag)) { + // ... and extra non-whitespace data is found... + SkipWhitespaceAndComments(is); + if (is.Peek() != '\0') { + // ... this is considered an error. + HandleError(state_, is); + return false; + } + } + + // Success! We are done! + return true; + } + + // Transition to the new state. + state_ = d; + + // If we parsed anything other than a delimiter, we invoked the handler, so we can return true now. + if (!IsIterativeParsingDelimiterState(n)) + return true; + } + + // We reached the end of file. + stack_.Clear(); + + if (state_ != IterativeParsingFinishState) { + HandleError(state_, is); + return false; + } + + return true; + } + + //! Check if token-by-token parsing JSON text is complete + /*! \return Whether the JSON has been fully decoded. + */ + RAPIDJSON_FORCEINLINE bool IterativeParseComplete() const { + return IsIterativeParsingCompleteState(state_); + } + + //! Whether a parse error has occurred in the last parsing. + bool HasParseError() const { return parseResult_.IsError(); } + + //! Get the \ref ParseErrorCode of last parsing. + ParseErrorCode GetParseErrorCode() const { return parseResult_.Code(); } + + //! Get the position of last parsing error in input, 0 otherwise. + size_t GetErrorOffset() const { return parseResult_.Offset(); } + +protected: + void SetParseError(ParseErrorCode code, size_t offset) { parseResult_.Set(code, offset); } + +private: + // Prohibit copy constructor & assignment operator. + GenericReader(const GenericReader&); + GenericReader& operator=(const GenericReader&); + + void ClearStack() { stack_.Clear(); } + + // clear stack on any exit from ParseStream, e.g. due to exception + struct ClearStackOnExit { + explicit ClearStackOnExit(GenericReader& r) : r_(r) {} + ~ClearStackOnExit() { r_.ClearStack(); } + private: + GenericReader& r_; + ClearStackOnExit(const ClearStackOnExit&); + ClearStackOnExit& operator=(const ClearStackOnExit&); + }; + + template + void SkipWhitespaceAndComments(InputStream& is) { + SkipWhitespace(is); + + if (parseFlags & kParseCommentsFlag) { + while (RAPIDJSON_UNLIKELY(Consume(is, '/'))) { + if (Consume(is, '*')) { + while (true) { + if (RAPIDJSON_UNLIKELY(is.Peek() == '\0')) + RAPIDJSON_PARSE_ERROR(kParseErrorUnspecificSyntaxError, is.Tell()); + else if (Consume(is, '*')) { + if (Consume(is, '/')) + break; + } + else + is.Take(); + } + } + else if (RAPIDJSON_LIKELY(Consume(is, '/'))) + while (is.Peek() != '\0' && is.Take() != '\n') {} + else + RAPIDJSON_PARSE_ERROR(kParseErrorUnspecificSyntaxError, is.Tell()); + + SkipWhitespace(is); + } + } + } + + // Parse object: { string : value, ... } + template + void ParseObject(InputStream& is, Handler& handler) { + RAPIDJSON_ASSERT(is.Peek() == '{'); + is.Take(); // Skip '{' + + if (RAPIDJSON_UNLIKELY(!handler.StartObject())) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + if (Consume(is, '}')) { + if (RAPIDJSON_UNLIKELY(!handler.EndObject(0))) // empty object + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + return; + } + + for (SizeType memberCount = 0;;) { + if (RAPIDJSON_UNLIKELY(is.Peek() != '"')) + RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissName, is.Tell()); + + ParseString(is, handler, true); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + if (RAPIDJSON_UNLIKELY(!Consume(is, ':'))) + RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissColon, is.Tell()); + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + ParseValue(is, handler); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + ++memberCount; + + switch (is.Peek()) { + case ',': + is.Take(); + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + break; + case '}': + is.Take(); + if (RAPIDJSON_UNLIKELY(!handler.EndObject(memberCount))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + return; + default: + RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissCommaOrCurlyBracket, is.Tell()); break; // This useless break is only for making warning and coverage happy + } + + if (parseFlags & kParseTrailingCommasFlag) { + if (is.Peek() == '}') { + if (RAPIDJSON_UNLIKELY(!handler.EndObject(memberCount))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + is.Take(); + return; + } + } + } + } + + // Parse array: [ value, ... ] + template + void ParseArray(InputStream& is, Handler& handler) { + RAPIDJSON_ASSERT(is.Peek() == '['); + is.Take(); // Skip '[' + + if (RAPIDJSON_UNLIKELY(!handler.StartArray())) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + if (Consume(is, ']')) { + if (RAPIDJSON_UNLIKELY(!handler.EndArray(0))) // empty array + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + return; + } + + for (SizeType elementCount = 0;;) { + ParseValue(is, handler); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + ++elementCount; + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + + if (Consume(is, ',')) { + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + } + else if (Consume(is, ']')) { + if (RAPIDJSON_UNLIKELY(!handler.EndArray(elementCount))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + return; + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorArrayMissCommaOrSquareBracket, is.Tell()); + + if (parseFlags & kParseTrailingCommasFlag) { + if (is.Peek() == ']') { + if (RAPIDJSON_UNLIKELY(!handler.EndArray(elementCount))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + is.Take(); + return; + } + } + } + } + + template + void ParseNull(InputStream& is, Handler& handler) { + RAPIDJSON_ASSERT(is.Peek() == 'n'); + is.Take(); + + if (RAPIDJSON_LIKELY(Consume(is, 'u') && Consume(is, 'l') && Consume(is, 'l'))) { + if (RAPIDJSON_UNLIKELY(!handler.Null())) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, is.Tell()); + } + + template + void ParseTrue(InputStream& is, Handler& handler) { + RAPIDJSON_ASSERT(is.Peek() == 't'); + is.Take(); + + if (RAPIDJSON_LIKELY(Consume(is, 'r') && Consume(is, 'u') && Consume(is, 'e'))) { + if (RAPIDJSON_UNLIKELY(!handler.Bool(true))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, is.Tell()); + } + + template + void ParseFalse(InputStream& is, Handler& handler) { + RAPIDJSON_ASSERT(is.Peek() == 'f'); + is.Take(); + + if (RAPIDJSON_LIKELY(Consume(is, 'a') && Consume(is, 'l') && Consume(is, 's') && Consume(is, 'e'))) { + if (RAPIDJSON_UNLIKELY(!handler.Bool(false))) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, is.Tell()); + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, is.Tell()); + } + + template + RAPIDJSON_FORCEINLINE static bool Consume(InputStream& is, typename InputStream::Ch expect) { + if (RAPIDJSON_LIKELY(is.Peek() == expect)) { + is.Take(); + return true; + } + else + return false; + } + + // Helper function to parse four hexadecimal digits in \uXXXX in ParseString(). + template + unsigned ParseHex4(InputStream& is, size_t escapeOffset) { + unsigned codepoint = 0; + for (int i = 0; i < 4; i++) { + Ch c = is.Peek(); + codepoint <<= 4; + codepoint += static_cast(c); + if (c >= '0' && c <= '9') + codepoint -= '0'; + else if (c >= 'A' && c <= 'F') + codepoint -= 'A' - 10; + else if (c >= 'a' && c <= 'f') + codepoint -= 'a' - 10; + else { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorStringUnicodeEscapeInvalidHex, escapeOffset); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(0); + } + is.Take(); + } + return codepoint; + } + + template + class StackStream { + public: + typedef CharType Ch; + + StackStream(internal::Stack& stack) : stack_(stack), length_(0) {} + RAPIDJSON_FORCEINLINE void Put(Ch c) { + *stack_.template Push() = c; + ++length_; + } + + RAPIDJSON_FORCEINLINE void* Push(SizeType count) { + length_ += count; + return stack_.template Push(count); + } + + size_t Length() const { return length_; } + + Ch* Pop() { + return stack_.template Pop(length_); + } + + private: + StackStream(const StackStream&); + StackStream& operator=(const StackStream&); + + internal::Stack& stack_; + SizeType length_; + }; + + // Parse string and generate String event. Different code paths for kParseInsituFlag. + template + void ParseString(InputStream& is, Handler& handler, bool isKey = false) { + internal::StreamLocalCopy copy(is); + InputStream& s(copy.s); + + RAPIDJSON_ASSERT(s.Peek() == '\"'); + s.Take(); // Skip '\"' + + bool success = false; + if (parseFlags & kParseInsituFlag) { + typename InputStream::Ch *head = s.PutBegin(); + ParseStringToStream(s, s); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + size_t length = s.PutEnd(head) - 1; + RAPIDJSON_ASSERT(length <= 0xFFFFFFFF); + const typename TargetEncoding::Ch* const str = reinterpret_cast(head); + success = (isKey ? handler.Key(str, SizeType(length), false) : handler.String(str, SizeType(length), false)); + } + else { + StackStream stackStream(stack_); + ParseStringToStream(s, stackStream); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + SizeType length = static_cast(stackStream.Length()) - 1; + const typename TargetEncoding::Ch* const str = stackStream.Pop(); + success = (isKey ? handler.Key(str, length, true) : handler.String(str, length, true)); + } + if (RAPIDJSON_UNLIKELY(!success)) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, s.Tell()); + } + + // Parse string to an output is + // This function handles the prefix/suffix double quotes, escaping, and optional encoding validation. + template + RAPIDJSON_FORCEINLINE void ParseStringToStream(InputStream& is, OutputStream& os) { +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#define Z16 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 + static const char escape[256] = { + Z16, Z16, 0, 0,'\"', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '/', + Z16, Z16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,'\\', 0, 0, 0, + 0, 0,'\b', 0, 0, 0,'\f', 0, 0, 0, 0, 0, 0, 0,'\n', 0, + 0, 0,'\r', 0,'\t', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + Z16, Z16, Z16, Z16, Z16, Z16, Z16, Z16 + }; +#undef Z16 +//!@endcond + + for (;;) { + // Scan and copy string before "\\\"" or < 0x20. This is an optional optimzation. + if (!(parseFlags & kParseValidateEncodingFlag)) + ScanCopyUnescapedString(is, os); + + Ch c = is.Peek(); + if (RAPIDJSON_UNLIKELY(c == '\\')) { // Escape + size_t escapeOffset = is.Tell(); // For invalid escaping, report the initial '\\' as error offset + is.Take(); + Ch e = is.Peek(); + if ((sizeof(Ch) == 1 || unsigned(e) < 256) && RAPIDJSON_LIKELY(escape[static_cast(e)])) { + is.Take(); + os.Put(static_cast(escape[static_cast(e)])); + } + else if ((parseFlags & kParseEscapedApostropheFlag) && RAPIDJSON_LIKELY(e == '\'')) { // Allow escaped apostrophe + is.Take(); + os.Put('\''); + } + else if (RAPIDJSON_LIKELY(e == 'u')) { // Unicode + is.Take(); + unsigned codepoint = ParseHex4(is, escapeOffset); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + if (RAPIDJSON_UNLIKELY(codepoint >= 0xD800 && codepoint <= 0xDFFF)) { + // high surrogate, check if followed by valid low surrogate + if (RAPIDJSON_LIKELY(codepoint <= 0xDBFF)) { + // Handle UTF-16 surrogate pair + if (RAPIDJSON_UNLIKELY(!Consume(is, '\\') || !Consume(is, 'u'))) + RAPIDJSON_PARSE_ERROR(kParseErrorStringUnicodeSurrogateInvalid, escapeOffset); + unsigned codepoint2 = ParseHex4(is, escapeOffset); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN_VOID; + if (RAPIDJSON_UNLIKELY(codepoint2 < 0xDC00 || codepoint2 > 0xDFFF)) + RAPIDJSON_PARSE_ERROR(kParseErrorStringUnicodeSurrogateInvalid, escapeOffset); + codepoint = (((codepoint - 0xD800) << 10) | (codepoint2 - 0xDC00)) + 0x10000; + } + // single low surrogate + else + { + RAPIDJSON_PARSE_ERROR(kParseErrorStringUnicodeSurrogateInvalid, escapeOffset); + } + } + TEncoding::Encode(os, codepoint); + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorStringEscapeInvalid, escapeOffset); + } + else if (RAPIDJSON_UNLIKELY(c == '"')) { // Closing double quote + is.Take(); + os.Put('\0'); // null-terminate the string + return; + } + else if (RAPIDJSON_UNLIKELY(static_cast(c) < 0x20)) { // RFC 4627: unescaped = %x20-21 / %x23-5B / %x5D-10FFFF + if (c == '\0') + RAPIDJSON_PARSE_ERROR(kParseErrorStringMissQuotationMark, is.Tell()); + else + RAPIDJSON_PARSE_ERROR(kParseErrorStringInvalidEncoding, is.Tell()); + } + else { + size_t offset = is.Tell(); + if (RAPIDJSON_UNLIKELY((parseFlags & kParseValidateEncodingFlag ? + !Transcoder::Validate(is, os) : + !Transcoder::Transcode(is, os)))) + RAPIDJSON_PARSE_ERROR(kParseErrorStringInvalidEncoding, offset); + } + } + } + + template + static RAPIDJSON_FORCEINLINE void ScanCopyUnescapedString(InputStream&, OutputStream&) { + // Do nothing for generic version + } + +#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) + // StringStream -> StackStream + static RAPIDJSON_FORCEINLINE void ScanCopyUnescapedString(StringStream& is, StackStream& os) { + const char* p = is.src_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = p; + return; + } + else + os.Put(*p++); + + // The rest of string using SIMD + static const char dquote[16] = { '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"' }; + static const char bslash[16] = { '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\' }; + static const char space[16] = { 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F }; + const __m128i dq = _mm_loadu_si128(reinterpret_cast(&dquote[0])); + const __m128i bs = _mm_loadu_si128(reinterpret_cast(&bslash[0])); + const __m128i sp = _mm_loadu_si128(reinterpret_cast(&space[0])); + + for (;; p += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + const __m128i t1 = _mm_cmpeq_epi8(s, dq); + const __m128i t2 = _mm_cmpeq_epi8(s, bs); + const __m128i t3 = _mm_cmpeq_epi8(_mm_max_epu8(s, sp), sp); // s < 0x20 <=> max(s, 0x1F) == 0x1F + const __m128i x = _mm_or_si128(_mm_or_si128(t1, t2), t3); + unsigned short r = static_cast(_mm_movemask_epi8(x)); + if (RAPIDJSON_UNLIKELY(r != 0)) { // some of characters is escaped + SizeType length; + #ifdef _MSC_VER // Find the index of first escaped + unsigned long offset; + _BitScanForward(&offset, r); + length = offset; + #else + length = static_cast(__builtin_ffs(r) - 1); + #endif + if (length != 0) { + char* q = reinterpret_cast(os.Push(length)); + for (size_t i = 0; i < length; i++) + q[i] = p[i]; + + p += length; + } + break; + } + _mm_storeu_si128(reinterpret_cast<__m128i *>(os.Push(16)), s); + } + + is.src_ = p; + } + + // InsituStringStream -> InsituStringStream + static RAPIDJSON_FORCEINLINE void ScanCopyUnescapedString(InsituStringStream& is, InsituStringStream& os) { + RAPIDJSON_ASSERT(&is == &os); + (void)os; + + if (is.src_ == is.dst_) { + SkipUnescapedString(is); + return; + } + + char* p = is.src_; + char *q = is.dst_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = p; + is.dst_ = q; + return; + } + else + *q++ = *p++; + + // The rest of string using SIMD + static const char dquote[16] = { '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"' }; + static const char bslash[16] = { '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\' }; + static const char space[16] = { 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F }; + const __m128i dq = _mm_loadu_si128(reinterpret_cast(&dquote[0])); + const __m128i bs = _mm_loadu_si128(reinterpret_cast(&bslash[0])); + const __m128i sp = _mm_loadu_si128(reinterpret_cast(&space[0])); + + for (;; p += 16, q += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + const __m128i t1 = _mm_cmpeq_epi8(s, dq); + const __m128i t2 = _mm_cmpeq_epi8(s, bs); + const __m128i t3 = _mm_cmpeq_epi8(_mm_max_epu8(s, sp), sp); // s < 0x20 <=> max(s, 0x1F) == 0x1F + const __m128i x = _mm_or_si128(_mm_or_si128(t1, t2), t3); + unsigned short r = static_cast(_mm_movemask_epi8(x)); + if (RAPIDJSON_UNLIKELY(r != 0)) { // some of characters is escaped + size_t length; +#ifdef _MSC_VER // Find the index of first escaped + unsigned long offset; + _BitScanForward(&offset, r); + length = offset; +#else + length = static_cast(__builtin_ffs(r) - 1); +#endif + for (const char* pend = p + length; p != pend; ) + *q++ = *p++; + break; + } + _mm_storeu_si128(reinterpret_cast<__m128i *>(q), s); + } + + is.src_ = p; + is.dst_ = q; + } + + // When read/write pointers are the same for insitu stream, just skip unescaped characters + static RAPIDJSON_FORCEINLINE void SkipUnescapedString(InsituStringStream& is) { + RAPIDJSON_ASSERT(is.src_ == is.dst_); + char* p = is.src_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + for (; p != nextAligned; p++) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = is.dst_ = p; + return; + } + + // The rest of string using SIMD + static const char dquote[16] = { '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"' }; + static const char bslash[16] = { '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\' }; + static const char space[16] = { 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F }; + const __m128i dq = _mm_loadu_si128(reinterpret_cast(&dquote[0])); + const __m128i bs = _mm_loadu_si128(reinterpret_cast(&bslash[0])); + const __m128i sp = _mm_loadu_si128(reinterpret_cast(&space[0])); + + for (;; p += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + const __m128i t1 = _mm_cmpeq_epi8(s, dq); + const __m128i t2 = _mm_cmpeq_epi8(s, bs); + const __m128i t3 = _mm_cmpeq_epi8(_mm_max_epu8(s, sp), sp); // s < 0x20 <=> max(s, 0x1F) == 0x1F + const __m128i x = _mm_or_si128(_mm_or_si128(t1, t2), t3); + unsigned short r = static_cast(_mm_movemask_epi8(x)); + if (RAPIDJSON_UNLIKELY(r != 0)) { // some of characters is escaped + size_t length; +#ifdef _MSC_VER // Find the index of first escaped + unsigned long offset; + _BitScanForward(&offset, r); + length = offset; +#else + length = static_cast(__builtin_ffs(r) - 1); +#endif + p += length; + break; + } + } + + is.src_ = is.dst_ = p; + } +#elif defined(RAPIDJSON_NEON) + // StringStream -> StackStream + static RAPIDJSON_FORCEINLINE void ScanCopyUnescapedString(StringStream& is, StackStream& os) { + const char* p = is.src_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = p; + return; + } + else + os.Put(*p++); + + // The rest of string using SIMD + const uint8x16_t s0 = vmovq_n_u8('"'); + const uint8x16_t s1 = vmovq_n_u8('\\'); + const uint8x16_t s2 = vmovq_n_u8('\b'); + const uint8x16_t s3 = vmovq_n_u8(32); + + for (;; p += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, s0); + x = vorrq_u8(x, vceqq_u8(s, s1)); + x = vorrq_u8(x, vceqq_u8(s, s2)); + x = vorrq_u8(x, vcltq_u8(s, s3)); + + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + SizeType length = 0; + bool escaped = false; + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + length = 8 + (lz >> 3); + escaped = true; + } + } else { + uint32_t lz = internal::clzll(low); + length = lz >> 3; + escaped = true; + } + if (RAPIDJSON_UNLIKELY(escaped)) { // some of characters is escaped + if (length != 0) { + char* q = reinterpret_cast(os.Push(length)); + for (size_t i = 0; i < length; i++) + q[i] = p[i]; + + p += length; + } + break; + } + vst1q_u8(reinterpret_cast(os.Push(16)), s); + } + + is.src_ = p; + } + + // InsituStringStream -> InsituStringStream + static RAPIDJSON_FORCEINLINE void ScanCopyUnescapedString(InsituStringStream& is, InsituStringStream& os) { + RAPIDJSON_ASSERT(&is == &os); + (void)os; + + if (is.src_ == is.dst_) { + SkipUnescapedString(is); + return; + } + + char* p = is.src_; + char *q = is.dst_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + while (p != nextAligned) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = p; + is.dst_ = q; + return; + } + else + *q++ = *p++; + + // The rest of string using SIMD + const uint8x16_t s0 = vmovq_n_u8('"'); + const uint8x16_t s1 = vmovq_n_u8('\\'); + const uint8x16_t s2 = vmovq_n_u8('\b'); + const uint8x16_t s3 = vmovq_n_u8(32); + + for (;; p += 16, q += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, s0); + x = vorrq_u8(x, vceqq_u8(s, s1)); + x = vorrq_u8(x, vceqq_u8(s, s2)); + x = vorrq_u8(x, vcltq_u8(s, s3)); + + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + SizeType length = 0; + bool escaped = false; + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + length = 8 + (lz >> 3); + escaped = true; + } + } else { + uint32_t lz = internal::clzll(low); + length = lz >> 3; + escaped = true; + } + if (RAPIDJSON_UNLIKELY(escaped)) { // some of characters is escaped + for (const char* pend = p + length; p != pend; ) { + *q++ = *p++; + } + break; + } + vst1q_u8(reinterpret_cast(q), s); + } + + is.src_ = p; + is.dst_ = q; + } + + // When read/write pointers are the same for insitu stream, just skip unescaped characters + static RAPIDJSON_FORCEINLINE void SkipUnescapedString(InsituStringStream& is) { + RAPIDJSON_ASSERT(is.src_ == is.dst_); + char* p = is.src_; + + // Scan one by one until alignment (unaligned load may cross page boundary and cause crash) + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + for (; p != nextAligned; p++) + if (RAPIDJSON_UNLIKELY(*p == '\"') || RAPIDJSON_UNLIKELY(*p == '\\') || RAPIDJSON_UNLIKELY(static_cast(*p) < 0x20)) { + is.src_ = is.dst_ = p; + return; + } + + // The rest of string using SIMD + const uint8x16_t s0 = vmovq_n_u8('"'); + const uint8x16_t s1 = vmovq_n_u8('\\'); + const uint8x16_t s2 = vmovq_n_u8('\b'); + const uint8x16_t s3 = vmovq_n_u8(32); + + for (;; p += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, s0); + x = vorrq_u8(x, vceqq_u8(s, s1)); + x = vorrq_u8(x, vceqq_u8(s, s2)); + x = vorrq_u8(x, vcltq_u8(s, s3)); + + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + p += 8 + (lz >> 3); + break; + } + } else { + uint32_t lz = internal::clzll(low); + p += lz >> 3; + break; + } + } + + is.src_ = is.dst_ = p; + } +#endif // RAPIDJSON_NEON + + template + class NumberStream; + + template + class NumberStream { + public: + typedef typename InputStream::Ch Ch; + + NumberStream(GenericReader& reader, InputStream& s) : is(s) { (void)reader; } + + RAPIDJSON_FORCEINLINE Ch Peek() const { return is.Peek(); } + RAPIDJSON_FORCEINLINE Ch TakePush() { return is.Take(); } + RAPIDJSON_FORCEINLINE Ch Take() { return is.Take(); } + RAPIDJSON_FORCEINLINE void Push(char) {} + + size_t Tell() { return is.Tell(); } + size_t Length() { return 0; } + const StackCharacter* Pop() { return 0; } + + protected: + NumberStream& operator=(const NumberStream&); + + InputStream& is; + }; + + template + class NumberStream : public NumberStream { + typedef NumberStream Base; + public: + NumberStream(GenericReader& reader, InputStream& s) : Base(reader, s), stackStream(reader.stack_) {} + + RAPIDJSON_FORCEINLINE Ch TakePush() { + stackStream.Put(static_cast(Base::is.Peek())); + return Base::is.Take(); + } + + RAPIDJSON_FORCEINLINE void Push(StackCharacter c) { + stackStream.Put(c); + } + + size_t Length() { return stackStream.Length(); } + + const StackCharacter* Pop() { + stackStream.Put('\0'); + return stackStream.Pop(); + } + + private: + StackStream stackStream; + }; + + template + class NumberStream : public NumberStream { + typedef NumberStream Base; + public: + NumberStream(GenericReader& reader, InputStream& s) : Base(reader, s) {} + + RAPIDJSON_FORCEINLINE Ch Take() { return Base::TakePush(); } + }; + + template + void ParseNumber(InputStream& is, Handler& handler) { + typedef typename internal::SelectIf, typename TargetEncoding::Ch, char>::Type NumberCharacter; + + internal::StreamLocalCopy copy(is); + NumberStream s(*this, copy.s); + + size_t startOffset = s.Tell(); + double d = 0.0; + bool useNanOrInf = false; + + // Parse minus + bool minus = Consume(s, '-'); + + // Parse int: zero / ( digit1-9 *DIGIT ) + unsigned i = 0; + uint64_t i64 = 0; + bool use64bit = false; + int significandDigit = 0; + if (RAPIDJSON_UNLIKELY(s.Peek() == '0')) { + i = 0; + s.TakePush(); + } + else if (RAPIDJSON_LIKELY(s.Peek() >= '1' && s.Peek() <= '9')) { + i = static_cast(s.TakePush() - '0'); + + if (minus) + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (RAPIDJSON_UNLIKELY(i >= 214748364)) { // 2^31 = 2147483648 + if (RAPIDJSON_LIKELY(i != 214748364 || s.Peek() > '8')) { + i64 = i; + use64bit = true; + break; + } + } + i = i * 10 + static_cast(s.TakePush() - '0'); + significandDigit++; + } + else + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (RAPIDJSON_UNLIKELY(i >= 429496729)) { // 2^32 - 1 = 4294967295 + if (RAPIDJSON_LIKELY(i != 429496729 || s.Peek() > '5')) { + i64 = i; + use64bit = true; + break; + } + } + i = i * 10 + static_cast(s.TakePush() - '0'); + significandDigit++; + } + } + // Parse NaN or Infinity here + else if ((parseFlags & kParseNanAndInfFlag) && RAPIDJSON_LIKELY((s.Peek() == 'I' || s.Peek() == 'N'))) { + if (Consume(s, 'N')) { + if (Consume(s, 'a') && Consume(s, 'N')) { + d = std::numeric_limits::quiet_NaN(); + useNanOrInf = true; + } + } + else if (RAPIDJSON_LIKELY(Consume(s, 'I'))) { + if (Consume(s, 'n') && Consume(s, 'f')) { + d = (minus ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()); + useNanOrInf = true; + + if (RAPIDJSON_UNLIKELY(s.Peek() == 'i' && !(Consume(s, 'i') && Consume(s, 'n') + && Consume(s, 'i') && Consume(s, 't') && Consume(s, 'y')))) { + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, s.Tell()); + } + } + } + + if (RAPIDJSON_UNLIKELY(!useNanOrInf)) { + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, s.Tell()); + } + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, s.Tell()); + + // Parse 64bit int + bool useDouble = false; + if (use64bit) { + if (minus) + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (RAPIDJSON_UNLIKELY(i64 >= RAPIDJSON_UINT64_C2(0x0CCCCCCC, 0xCCCCCCCC))) // 2^63 = 9223372036854775808 + if (RAPIDJSON_LIKELY(i64 != RAPIDJSON_UINT64_C2(0x0CCCCCCC, 0xCCCCCCCC) || s.Peek() > '8')) { + d = static_cast(i64); + useDouble = true; + break; + } + i64 = i64 * 10 + static_cast(s.TakePush() - '0'); + significandDigit++; + } + else + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (RAPIDJSON_UNLIKELY(i64 >= RAPIDJSON_UINT64_C2(0x19999999, 0x99999999))) // 2^64 - 1 = 18446744073709551615 + if (RAPIDJSON_LIKELY(i64 != RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) || s.Peek() > '5')) { + d = static_cast(i64); + useDouble = true; + break; + } + i64 = i64 * 10 + static_cast(s.TakePush() - '0'); + significandDigit++; + } + } + + // Force double for big integer + if (useDouble) { + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + d = d * 10 + (s.TakePush() - '0'); + } + } + + // Parse frac = decimal-point 1*DIGIT + int expFrac = 0; + size_t decimalPosition; + if (!useNanOrInf && Consume(s, '.')) { + decimalPosition = s.Length(); + + if (RAPIDJSON_UNLIKELY(!(s.Peek() >= '0' && s.Peek() <= '9'))) + RAPIDJSON_PARSE_ERROR(kParseErrorNumberMissFraction, s.Tell()); + + if (!useDouble) { +#if RAPIDJSON_64BIT + // Use i64 to store significand in 64-bit architecture + if (!use64bit) + i64 = i; + + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (i64 > RAPIDJSON_UINT64_C2(0x1FFFFF, 0xFFFFFFFF)) // 2^53 - 1 for fast path + break; + else { + i64 = i64 * 10 + static_cast(s.TakePush() - '0'); + --expFrac; + if (i64 != 0) + significandDigit++; + } + } + + d = static_cast(i64); +#else + // Use double to store significand in 32-bit architecture + d = static_cast(use64bit ? i64 : i); +#endif + useDouble = true; + } + + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + if (significandDigit < 17) { + d = d * 10.0 + (s.TakePush() - '0'); + --expFrac; + if (RAPIDJSON_LIKELY(d > 0.0)) + significandDigit++; + } + else + s.TakePush(); + } + } + else + decimalPosition = s.Length(); // decimal position at the end of integer. + + // Parse exp = e [ minus / plus ] 1*DIGIT + int exp = 0; + if (!useNanOrInf && (Consume(s, 'e') || Consume(s, 'E'))) { + if (!useDouble) { + d = static_cast(use64bit ? i64 : i); + useDouble = true; + } + + bool expMinus = false; + if (Consume(s, '+')) + ; + else if (Consume(s, '-')) + expMinus = true; + + if (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + exp = static_cast(s.Take() - '0'); + if (expMinus) { + // (exp + expFrac) must not underflow int => we're detecting when -exp gets + // dangerously close to INT_MIN (a pessimistic next digit 9 would push it into + // underflow territory): + // + // -(exp * 10 + 9) + expFrac >= INT_MIN + // <=> exp <= (expFrac - INT_MIN - 9) / 10 + RAPIDJSON_ASSERT(expFrac <= 0); + int maxExp = (expFrac + 2147483639) / 10; + + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + exp = exp * 10 + static_cast(s.Take() - '0'); + if (RAPIDJSON_UNLIKELY(exp > maxExp)) { + while (RAPIDJSON_UNLIKELY(s.Peek() >= '0' && s.Peek() <= '9')) // Consume the rest of exponent + s.Take(); + } + } + } + else { // positive exp + int maxExp = 308 - expFrac; + while (RAPIDJSON_LIKELY(s.Peek() >= '0' && s.Peek() <= '9')) { + exp = exp * 10 + static_cast(s.Take() - '0'); + if (RAPIDJSON_UNLIKELY(exp > maxExp)) + RAPIDJSON_PARSE_ERROR(kParseErrorNumberTooBig, startOffset); + } + } + } + else + RAPIDJSON_PARSE_ERROR(kParseErrorNumberMissExponent, s.Tell()); + + if (expMinus) + exp = -exp; + } + + // Finish parsing, call event according to the type of number. + bool cont = true; + + if (parseFlags & kParseNumbersAsStringsFlag) { + if (parseFlags & kParseInsituFlag) { + s.Pop(); // Pop stack no matter if it will be used or not. + typename InputStream::Ch* head = is.PutBegin(); + const size_t length = s.Tell() - startOffset; + RAPIDJSON_ASSERT(length <= 0xFFFFFFFF); + // unable to insert the \0 character here, it will erase the comma after this number + const typename TargetEncoding::Ch* const str = reinterpret_cast(head); + cont = handler.RawNumber(str, SizeType(length), false); + } + else { + SizeType numCharsToCopy = static_cast(s.Length()); + GenericStringStream > srcStream(s.Pop()); + StackStream dstStream(stack_); + while (numCharsToCopy--) { + Transcoder, TargetEncoding>::Transcode(srcStream, dstStream); + } + dstStream.Put('\0'); + const typename TargetEncoding::Ch* str = dstStream.Pop(); + const SizeType length = static_cast(dstStream.Length()) - 1; + cont = handler.RawNumber(str, SizeType(length), true); + } + } + else { + size_t length = s.Length(); + const NumberCharacter* decimal = s.Pop(); // Pop stack no matter if it will be used or not. + + if (useDouble) { + int p = exp + expFrac; + if (parseFlags & kParseFullPrecisionFlag) + d = internal::StrtodFullPrecision(d, p, decimal, length, decimalPosition, exp); + else + d = internal::StrtodNormalPrecision(d, p); + + // Use > max, instead of == inf, to fix bogus warning -Wfloat-equal + if (d > (std::numeric_limits::max)()) { + // Overflow + // TODO: internal::StrtodX should report overflow (or underflow) + RAPIDJSON_PARSE_ERROR(kParseErrorNumberTooBig, startOffset); + } + + cont = handler.Double(minus ? -d : d); + } + else if (useNanOrInf) { + cont = handler.Double(d); + } + else { + if (use64bit) { + if (minus) + cont = handler.Int64(static_cast(~i64 + 1)); + else + cont = handler.Uint64(i64); + } + else { + if (minus) + cont = handler.Int(static_cast(~i + 1)); + else + cont = handler.Uint(i); + } + } + } + if (RAPIDJSON_UNLIKELY(!cont)) + RAPIDJSON_PARSE_ERROR(kParseErrorTermination, startOffset); + } + + // Parse any JSON value + template + void ParseValue(InputStream& is, Handler& handler) { + switch (is.Peek()) { + case 'n': ParseNull (is, handler); break; + case 't': ParseTrue (is, handler); break; + case 'f': ParseFalse (is, handler); break; + case '"': ParseString(is, handler); break; + case '{': ParseObject(is, handler); break; + case '[': ParseArray (is, handler); break; + default : + ParseNumber(is, handler); + break; + + } + } + + // Iterative Parsing + + // States + enum IterativeParsingState { + IterativeParsingFinishState = 0, // sink states at top + IterativeParsingErrorState, // sink states at top + IterativeParsingStartState, + + // Object states + IterativeParsingObjectInitialState, + IterativeParsingMemberKeyState, + IterativeParsingMemberValueState, + IterativeParsingObjectFinishState, + + // Array states + IterativeParsingArrayInitialState, + IterativeParsingElementState, + IterativeParsingArrayFinishState, + + // Single value state + IterativeParsingValueState, + + // Delimiter states (at bottom) + IterativeParsingElementDelimiterState, + IterativeParsingMemberDelimiterState, + IterativeParsingKeyValueDelimiterState, + + cIterativeParsingStateCount + }; + + // Tokens + enum Token { + LeftBracketToken = 0, + RightBracketToken, + + LeftCurlyBracketToken, + RightCurlyBracketToken, + + CommaToken, + ColonToken, + + StringToken, + FalseToken, + TrueToken, + NullToken, + NumberToken, + + kTokenCount + }; + + RAPIDJSON_FORCEINLINE Token Tokenize(Ch c) const { + +//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN +#define N NumberToken +#define N16 N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N + // Maps from ASCII to Token + static const unsigned char tokenMap[256] = { + N16, // 00~0F + N16, // 10~1F + N, N, StringToken, N, N, N, N, N, N, N, N, N, CommaToken, N, N, N, // 20~2F + N, N, N, N, N, N, N, N, N, N, ColonToken, N, N, N, N, N, // 30~3F + N16, // 40~4F + N, N, N, N, N, N, N, N, N, N, N, LeftBracketToken, N, RightBracketToken, N, N, // 50~5F + N, N, N, N, N, N, FalseToken, N, N, N, N, N, N, N, NullToken, N, // 60~6F + N, N, N, N, TrueToken, N, N, N, N, N, N, LeftCurlyBracketToken, N, RightCurlyBracketToken, N, N, // 70~7F + N16, N16, N16, N16, N16, N16, N16, N16 // 80~FF + }; +#undef N +#undef N16 +//!@endcond + + if (sizeof(Ch) == 1 || static_cast(c) < 256) + return static_cast(tokenMap[static_cast(c)]); + else + return NumberToken; + } + + RAPIDJSON_FORCEINLINE IterativeParsingState Predict(IterativeParsingState state, Token token) const { + // current state x one lookahead token -> new state + static const char G[cIterativeParsingStateCount][kTokenCount] = { + // Finish(sink state) + { + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState + }, + // Error(sink state) + { + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState + }, + // Start + { + IterativeParsingArrayInitialState, // Left bracket + IterativeParsingErrorState, // Right bracket + IterativeParsingObjectInitialState, // Left curly bracket + IterativeParsingErrorState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingValueState, // String + IterativeParsingValueState, // False + IterativeParsingValueState, // True + IterativeParsingValueState, // Null + IterativeParsingValueState // Number + }, + // ObjectInitial + { + IterativeParsingErrorState, // Left bracket + IterativeParsingErrorState, // Right bracket + IterativeParsingErrorState, // Left curly bracket + IterativeParsingObjectFinishState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingMemberKeyState, // String + IterativeParsingErrorState, // False + IterativeParsingErrorState, // True + IterativeParsingErrorState, // Null + IterativeParsingErrorState // Number + }, + // MemberKey + { + IterativeParsingErrorState, // Left bracket + IterativeParsingErrorState, // Right bracket + IterativeParsingErrorState, // Left curly bracket + IterativeParsingErrorState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingKeyValueDelimiterState, // Colon + IterativeParsingErrorState, // String + IterativeParsingErrorState, // False + IterativeParsingErrorState, // True + IterativeParsingErrorState, // Null + IterativeParsingErrorState // Number + }, + // MemberValue + { + IterativeParsingErrorState, // Left bracket + IterativeParsingErrorState, // Right bracket + IterativeParsingErrorState, // Left curly bracket + IterativeParsingObjectFinishState, // Right curly bracket + IterativeParsingMemberDelimiterState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingErrorState, // String + IterativeParsingErrorState, // False + IterativeParsingErrorState, // True + IterativeParsingErrorState, // Null + IterativeParsingErrorState // Number + }, + // ObjectFinish(sink state) + { + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState + }, + // ArrayInitial + { + IterativeParsingArrayInitialState, // Left bracket(push Element state) + IterativeParsingArrayFinishState, // Right bracket + IterativeParsingObjectInitialState, // Left curly bracket(push Element state) + IterativeParsingErrorState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingElementState, // String + IterativeParsingElementState, // False + IterativeParsingElementState, // True + IterativeParsingElementState, // Null + IterativeParsingElementState // Number + }, + // Element + { + IterativeParsingErrorState, // Left bracket + IterativeParsingArrayFinishState, // Right bracket + IterativeParsingErrorState, // Left curly bracket + IterativeParsingErrorState, // Right curly bracket + IterativeParsingElementDelimiterState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingErrorState, // String + IterativeParsingErrorState, // False + IterativeParsingErrorState, // True + IterativeParsingErrorState, // Null + IterativeParsingErrorState // Number + }, + // ArrayFinish(sink state) + { + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState + }, + // Single Value (sink state) + { + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, IterativeParsingErrorState, + IterativeParsingErrorState + }, + // ElementDelimiter + { + IterativeParsingArrayInitialState, // Left bracket(push Element state) + IterativeParsingArrayFinishState, // Right bracket + IterativeParsingObjectInitialState, // Left curly bracket(push Element state) + IterativeParsingErrorState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingElementState, // String + IterativeParsingElementState, // False + IterativeParsingElementState, // True + IterativeParsingElementState, // Null + IterativeParsingElementState // Number + }, + // MemberDelimiter + { + IterativeParsingErrorState, // Left bracket + IterativeParsingErrorState, // Right bracket + IterativeParsingErrorState, // Left curly bracket + IterativeParsingObjectFinishState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingMemberKeyState, // String + IterativeParsingErrorState, // False + IterativeParsingErrorState, // True + IterativeParsingErrorState, // Null + IterativeParsingErrorState // Number + }, + // KeyValueDelimiter + { + IterativeParsingArrayInitialState, // Left bracket(push MemberValue state) + IterativeParsingErrorState, // Right bracket + IterativeParsingObjectInitialState, // Left curly bracket(push MemberValue state) + IterativeParsingErrorState, // Right curly bracket + IterativeParsingErrorState, // Comma + IterativeParsingErrorState, // Colon + IterativeParsingMemberValueState, // String + IterativeParsingMemberValueState, // False + IterativeParsingMemberValueState, // True + IterativeParsingMemberValueState, // Null + IterativeParsingMemberValueState // Number + }, + }; // End of G + + return static_cast(G[state][token]); + } + + // Make an advance in the token stream and state based on the candidate destination state which was returned by Transit(). + // May return a new state on state pop. + template + RAPIDJSON_FORCEINLINE IterativeParsingState Transit(IterativeParsingState src, Token token, IterativeParsingState dst, InputStream& is, Handler& handler) { + (void)token; + + switch (dst) { + case IterativeParsingErrorState: + return dst; + + case IterativeParsingObjectInitialState: + case IterativeParsingArrayInitialState: + { + // Push the state(Element or MemeberValue) if we are nested in another array or value of member. + // In this way we can get the correct state on ObjectFinish or ArrayFinish by frame pop. + IterativeParsingState n = src; + if (src == IterativeParsingArrayInitialState || src == IterativeParsingElementDelimiterState) + n = IterativeParsingElementState; + else if (src == IterativeParsingKeyValueDelimiterState) + n = IterativeParsingMemberValueState; + // Push current state. + *stack_.template Push(1) = n; + // Initialize and push the member/element count. + *stack_.template Push(1) = 0; + // Call handler + bool hr = (dst == IterativeParsingObjectInitialState) ? handler.StartObject() : handler.StartArray(); + // On handler short circuits the parsing. + if (!hr) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorTermination, is.Tell()); + return IterativeParsingErrorState; + } + else { + is.Take(); + return dst; + } + } + + case IterativeParsingMemberKeyState: + ParseString(is, handler, true); + if (HasParseError()) + return IterativeParsingErrorState; + else + return dst; + + case IterativeParsingKeyValueDelimiterState: + RAPIDJSON_ASSERT(token == ColonToken); + is.Take(); + return dst; + + case IterativeParsingMemberValueState: + // Must be non-compound value. Or it would be ObjectInitial or ArrayInitial state. + ParseValue(is, handler); + if (HasParseError()) { + return IterativeParsingErrorState; + } + return dst; + + case IterativeParsingElementState: + // Must be non-compound value. Or it would be ObjectInitial or ArrayInitial state. + ParseValue(is, handler); + if (HasParseError()) { + return IterativeParsingErrorState; + } + return dst; + + case IterativeParsingMemberDelimiterState: + case IterativeParsingElementDelimiterState: + is.Take(); + // Update member/element count. + *stack_.template Top() = *stack_.template Top() + 1; + return dst; + + case IterativeParsingObjectFinishState: + { + // Transit from delimiter is only allowed when trailing commas are enabled + if (!(parseFlags & kParseTrailingCommasFlag) && src == IterativeParsingMemberDelimiterState) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorObjectMissName, is.Tell()); + return IterativeParsingErrorState; + } + // Get member count. + SizeType c = *stack_.template Pop(1); + // If the object is not empty, count the last member. + if (src == IterativeParsingMemberValueState) + ++c; + // Restore the state. + IterativeParsingState n = static_cast(*stack_.template Pop(1)); + // Transit to Finish state if this is the topmost scope. + if (n == IterativeParsingStartState) + n = IterativeParsingFinishState; + // Call handler + bool hr = handler.EndObject(c); + // On handler short circuits the parsing. + if (!hr) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorTermination, is.Tell()); + return IterativeParsingErrorState; + } + else { + is.Take(); + return n; + } + } + + case IterativeParsingArrayFinishState: + { + // Transit from delimiter is only allowed when trailing commas are enabled + if (!(parseFlags & kParseTrailingCommasFlag) && src == IterativeParsingElementDelimiterState) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorValueInvalid, is.Tell()); + return IterativeParsingErrorState; + } + // Get element count. + SizeType c = *stack_.template Pop(1); + // If the array is not empty, count the last element. + if (src == IterativeParsingElementState) + ++c; + // Restore the state. + IterativeParsingState n = static_cast(*stack_.template Pop(1)); + // Transit to Finish state if this is the topmost scope. + if (n == IterativeParsingStartState) + n = IterativeParsingFinishState; + // Call handler + bool hr = handler.EndArray(c); + // On handler short circuits the parsing. + if (!hr) { + RAPIDJSON_PARSE_ERROR_NORETURN(kParseErrorTermination, is.Tell()); + return IterativeParsingErrorState; + } + else { + is.Take(); + return n; + } + } + + default: + // This branch is for IterativeParsingValueState actually. + // Use `default:` rather than + // `case IterativeParsingValueState:` is for code coverage. + + // The IterativeParsingStartState is not enumerated in this switch-case. + // It is impossible for that case. And it can be caught by following assertion. + + // The IterativeParsingFinishState is not enumerated in this switch-case either. + // It is a "derivative" state which cannot triggered from Predict() directly. + // Therefore it cannot happen here. And it can be caught by following assertion. + RAPIDJSON_ASSERT(dst == IterativeParsingValueState); + + // Must be non-compound value. Or it would be ObjectInitial or ArrayInitial state. + ParseValue(is, handler); + if (HasParseError()) { + return IterativeParsingErrorState; + } + return IterativeParsingFinishState; + } + } + + template + void HandleError(IterativeParsingState src, InputStream& is) { + if (HasParseError()) { + // Error flag has been set. + return; + } + + switch (src) { + case IterativeParsingStartState: RAPIDJSON_PARSE_ERROR(kParseErrorDocumentEmpty, is.Tell()); return; + case IterativeParsingFinishState: RAPIDJSON_PARSE_ERROR(kParseErrorDocumentRootNotSingular, is.Tell()); return; + case IterativeParsingObjectInitialState: + case IterativeParsingMemberDelimiterState: RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissName, is.Tell()); return; + case IterativeParsingMemberKeyState: RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissColon, is.Tell()); return; + case IterativeParsingMemberValueState: RAPIDJSON_PARSE_ERROR(kParseErrorObjectMissCommaOrCurlyBracket, is.Tell()); return; + case IterativeParsingKeyValueDelimiterState: + case IterativeParsingArrayInitialState: + case IterativeParsingElementDelimiterState: RAPIDJSON_PARSE_ERROR(kParseErrorValueInvalid, is.Tell()); return; + default: RAPIDJSON_ASSERT(src == IterativeParsingElementState); RAPIDJSON_PARSE_ERROR(kParseErrorArrayMissCommaOrSquareBracket, is.Tell()); return; + } + } + + RAPIDJSON_FORCEINLINE bool IsIterativeParsingDelimiterState(IterativeParsingState s) const { + return s >= IterativeParsingElementDelimiterState; + } + + RAPIDJSON_FORCEINLINE bool IsIterativeParsingCompleteState(IterativeParsingState s) const { + return s <= IterativeParsingErrorState; + } + + template + ParseResult IterativeParse(InputStream& is, Handler& handler) { + parseResult_.Clear(); + ClearStackOnExit scope(*this); + IterativeParsingState state = IterativeParsingStartState; + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + while (is.Peek() != '\0') { + Token t = Tokenize(is.Peek()); + IterativeParsingState n = Predict(state, t); + IterativeParsingState d = Transit(state, t, n, is, handler); + + if (d == IterativeParsingErrorState) { + HandleError(state, is); + break; + } + + state = d; + + // Do not further consume streams if a root JSON has been parsed. + if ((parseFlags & kParseStopWhenDoneFlag) && state == IterativeParsingFinishState) + break; + + SkipWhitespaceAndComments(is); + RAPIDJSON_PARSE_ERROR_EARLY_RETURN(parseResult_); + } + + // Handle the end of file. + if (state != IterativeParsingFinishState) + HandleError(state, is); + + return parseResult_; + } + + static const size_t kDefaultStackCapacity = 256; //!< Default stack capacity in bytes for storing a single decoded string. + internal::Stack stack_; //!< A stack for storing decoded string temporarily during non-destructive parsing. + ParseResult parseResult_; + IterativeParsingState state_; +}; // class GenericReader + +//! Reader with UTF8 encoding and default allocator. +typedef GenericReader, UTF8<> > Reader; + +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) || defined(_MSC_VER) +RAPIDJSON_DIAG_POP +#endif + + +#ifdef __GNUC__ +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_READER_H_ diff --git a/include/rapidjson/schema.h b/include/rapidjson/schema.h new file mode 100644 index 0000000000..f049285f4e --- /dev/null +++ b/include/rapidjson/schema.h @@ -0,0 +1,3261 @@ +// Tencent is pleased to support the open source community by making RapidJSON available-> +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip-> All rights reserved-> +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License-> You may obtain a copy of the License at +// +// http://opensource->org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied-> See the License for the +// specific language governing permissions and limitations under the License-> + +#ifndef RAPIDJSON_SCHEMA_H_ +#define RAPIDJSON_SCHEMA_H_ + +#include "document.h" +#include "pointer.h" +#include "stringbuffer.h" +#include "error/en.h" +#include "uri.h" +#include // abs, floor + +#if !defined(RAPIDJSON_SCHEMA_USE_INTERNALREGEX) +#define RAPIDJSON_SCHEMA_USE_INTERNALREGEX 1 +#endif + +#if !defined(RAPIDJSON_SCHEMA_USE_STDREGEX) || !(__cplusplus >=201103L || (defined(_MSC_VER) && _MSC_VER >= 1800)) +#define RAPIDJSON_SCHEMA_USE_STDREGEX 0 +#endif + +#if RAPIDJSON_SCHEMA_USE_INTERNALREGEX +#include "internal/regex.h" +#elif RAPIDJSON_SCHEMA_USE_STDREGEX +#include +#endif + +#if RAPIDJSON_SCHEMA_USE_INTERNALREGEX || RAPIDJSON_SCHEMA_USE_STDREGEX +#define RAPIDJSON_SCHEMA_HAS_REGEX 1 +#else +#define RAPIDJSON_SCHEMA_HAS_REGEX 0 +#endif + +#ifndef RAPIDJSON_SCHEMA_VERBOSE +#define RAPIDJSON_SCHEMA_VERBOSE 0 +#endif + +RAPIDJSON_DIAG_PUSH + +#if defined(__GNUC__) +RAPIDJSON_DIAG_OFF(effc++) +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_OFF(weak-vtables) +RAPIDJSON_DIAG_OFF(exit-time-destructors) +RAPIDJSON_DIAG_OFF(c++98-compat-pedantic) +RAPIDJSON_DIAG_OFF(variadic-macros) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// Verbose Utilities + +#if RAPIDJSON_SCHEMA_VERBOSE + +namespace internal { + +inline void PrintInvalidKeywordData(const char* keyword) { + printf(" Fail keyword: '%s'\n", keyword); +} + +inline void PrintInvalidKeywordData(const wchar_t* keyword) { + wprintf(L" Fail keyword: '%ls'\n", keyword); +} + +inline void PrintInvalidDocumentData(const char* document) { + printf(" Fail document: '%s'\n", document); +} + +inline void PrintInvalidDocumentData(const wchar_t* document) { + wprintf(L" Fail document: '%ls'\n", document); +} + +inline void PrintValidatorPointersData(const char* s, const char* d, unsigned depth) { + printf(" Sch: %*s'%s'\n Doc: %*s'%s'\n", depth * 4, " ", s, depth * 4, " ", d); +} + +inline void PrintValidatorPointersData(const wchar_t* s, const wchar_t* d, unsigned depth) { + wprintf(L" Sch: %*ls'%ls'\n Doc: %*ls'%ls'\n", depth * 4, L" ", s, depth * 4, L" ", d); +} + +inline void PrintSchemaIdsData(const char* base, const char* local, const char* resolved) { + printf(" Resolving id: Base: '%s', Local: '%s', Resolved: '%s'\n", base, local, resolved); +} + +inline void PrintSchemaIdsData(const wchar_t* base, const wchar_t* local, const wchar_t* resolved) { + wprintf(L" Resolving id: Base: '%ls', Local: '%ls', Resolved: '%ls'\n", base, local, resolved); +} + +inline void PrintMethodData(const char* method) { + printf("%s\n", method); +} + +inline void PrintMethodData(const char* method, bool b) { + printf("%s, Data: '%s'\n", method, b ? "true" : "false"); +} + +inline void PrintMethodData(const char* method, int64_t i) { + printf("%s, Data: '%" PRId64 "'\n", method, i); +} + +inline void PrintMethodData(const char* method, uint64_t u) { + printf("%s, Data: '%" PRIu64 "'\n", method, u); +} + +inline void PrintMethodData(const char* method, double d) { + printf("%s, Data: '%lf'\n", method, d); +} + +inline void PrintMethodData(const char* method, const char* s) { + printf("%s, Data: '%s'\n", method, s); +} + +inline void PrintMethodData(const char* method, const wchar_t* s) { + wprintf(L"%hs, Data: '%ls'\n", method, s); +} + +inline void PrintMethodData(const char* method, const char* s1, const char* s2) { + printf("%s, Data: '%s', '%s'\n", method, s1, s2); +} + +inline void PrintMethodData(const char* method, const wchar_t* s1, const wchar_t* s2) { + wprintf(L"%hs, Data: '%ls', '%ls'\n", method, s1, s2); +} + +} // namespace internal + +#endif // RAPIDJSON_SCHEMA_VERBOSE + +#ifndef RAPIDJSON_SCHEMA_PRINT +#if RAPIDJSON_SCHEMA_VERBOSE +#define RAPIDJSON_SCHEMA_PRINT(name, ...) internal::Print##name##Data(__VA_ARGS__) +#else +#define RAPIDJSON_SCHEMA_PRINT(name, ...) +#endif +#endif + +/////////////////////////////////////////////////////////////////////////////// +// RAPIDJSON_INVALID_KEYWORD_RETURN + +#define RAPIDJSON_INVALID_KEYWORD_RETURN(code)\ +RAPIDJSON_MULTILINEMACRO_BEGIN\ + context.invalidCode = code;\ + context.invalidKeyword = SchemaType::GetValidateErrorKeyword(code).GetString();\ + RAPIDJSON_SCHEMA_PRINT(InvalidKeyword, context.invalidKeyword);\ + return false;\ +RAPIDJSON_MULTILINEMACRO_END + +/////////////////////////////////////////////////////////////////////////////// +// ValidateFlag + +/*! \def RAPIDJSON_VALIDATE_DEFAULT_FLAGS + \ingroup RAPIDJSON_CONFIG + \brief User-defined kValidateDefaultFlags definition. + + User can define this as any \c ValidateFlag combinations. +*/ +#ifndef RAPIDJSON_VALIDATE_DEFAULT_FLAGS +#define RAPIDJSON_VALIDATE_DEFAULT_FLAGS kValidateNoFlags +#endif + +//! Combination of validate flags +enum ValidateFlag { + kValidateNoFlags = 0, //!< No flags are set. + kValidateContinueOnErrorFlag = 1, //!< Don't stop after first validation error. + kValidateReadFlag = 2, //!< Validation is for a read semantic. + kValidateWriteFlag = 4, //!< Validation is for a write semantic. + kValidateDefaultFlags = RAPIDJSON_VALIDATE_DEFAULT_FLAGS //!< Default validate flags. Can be customized by defining RAPIDJSON_VALIDATE_DEFAULT_FLAGS +}; + +/////////////////////////////////////////////////////////////////////////////// +// Specification +enum SchemaDraft { + kDraftUnknown = -1, + kDraftNone = 0, + kDraft03 = 3, + kDraftMin = 4, //!< Current minimum supported draft + kDraft04 = 4, + kDraft05 = 5, + kDraftMax = 5, //!< Current maximum supported draft + kDraft06 = 6, + kDraft07 = 7, + kDraft2019_09 = 8, + kDraft2020_12 = 9 +}; + +enum OpenApiVersion { + kVersionUnknown = -1, + kVersionNone = 0, + kVersionMin = 2, //!< Current minimum supported version + kVersion20 = 2, + kVersion30 = 3, + kVersionMax = 3, //!< Current maximum supported version + kVersion31 = 4, +}; + +struct Specification { + Specification(SchemaDraft d) : draft(d), oapi(kVersionNone) {} + Specification(OpenApiVersion o) : oapi(o) { + if (oapi == kVersion20) draft = kDraft04; + else if (oapi == kVersion30) draft = kDraft05; + else if (oapi == kVersion31) draft = kDraft2020_12; + else draft = kDraft04; + } + ~Specification() {} + bool IsSupported() const { + return ((draft >= kDraftMin && draft <= kDraftMax) && ((oapi == kVersionNone) || (oapi >= kVersionMin && oapi <= kVersionMax))); + } + SchemaDraft draft; + OpenApiVersion oapi; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Forward declarations + +template +class GenericSchemaDocument; + +namespace internal { + +template +class Schema; + +/////////////////////////////////////////////////////////////////////////////// +// ISchemaValidator + +class ISchemaValidator { +public: + virtual ~ISchemaValidator() {} + virtual bool IsValid() const = 0; + virtual void SetValidateFlags(unsigned flags) = 0; + virtual unsigned GetValidateFlags() const = 0; +}; + +/////////////////////////////////////////////////////////////////////////////// +// ISchemaStateFactory + +template +class ISchemaStateFactory { +public: + virtual ~ISchemaStateFactory() {} + virtual ISchemaValidator* CreateSchemaValidator(const SchemaType&, const bool inheritContinueOnErrors) = 0; + virtual void DestroySchemaValidator(ISchemaValidator* validator) = 0; + virtual void* CreateHasher() = 0; + virtual uint64_t GetHashCode(void* hasher) = 0; + virtual void DestroryHasher(void* hasher) = 0; + virtual void* MallocState(size_t size) = 0; + virtual void FreeState(void* p) = 0; +}; + +/////////////////////////////////////////////////////////////////////////////// +// IValidationErrorHandler + +template +class IValidationErrorHandler { +public: + typedef typename SchemaType::Ch Ch; + typedef typename SchemaType::SValue SValue; + + virtual ~IValidationErrorHandler() {} + + virtual void NotMultipleOf(int64_t actual, const SValue& expected) = 0; + virtual void NotMultipleOf(uint64_t actual, const SValue& expected) = 0; + virtual void NotMultipleOf(double actual, const SValue& expected) = 0; + virtual void AboveMaximum(int64_t actual, const SValue& expected, bool exclusive) = 0; + virtual void AboveMaximum(uint64_t actual, const SValue& expected, bool exclusive) = 0; + virtual void AboveMaximum(double actual, const SValue& expected, bool exclusive) = 0; + virtual void BelowMinimum(int64_t actual, const SValue& expected, bool exclusive) = 0; + virtual void BelowMinimum(uint64_t actual, const SValue& expected, bool exclusive) = 0; + virtual void BelowMinimum(double actual, const SValue& expected, bool exclusive) = 0; + + virtual void TooLong(const Ch* str, SizeType length, SizeType expected) = 0; + virtual void TooShort(const Ch* str, SizeType length, SizeType expected) = 0; + virtual void DoesNotMatch(const Ch* str, SizeType length) = 0; + + virtual void DisallowedItem(SizeType index) = 0; + virtual void TooFewItems(SizeType actualCount, SizeType expectedCount) = 0; + virtual void TooManyItems(SizeType actualCount, SizeType expectedCount) = 0; + virtual void DuplicateItems(SizeType index1, SizeType index2) = 0; + + virtual void TooManyProperties(SizeType actualCount, SizeType expectedCount) = 0; + virtual void TooFewProperties(SizeType actualCount, SizeType expectedCount) = 0; + virtual void StartMissingProperties() = 0; + virtual void AddMissingProperty(const SValue& name) = 0; + virtual bool EndMissingProperties() = 0; + virtual void PropertyViolations(ISchemaValidator** subvalidators, SizeType count) = 0; + virtual void DisallowedProperty(const Ch* name, SizeType length) = 0; + + virtual void StartDependencyErrors() = 0; + virtual void StartMissingDependentProperties() = 0; + virtual void AddMissingDependentProperty(const SValue& targetName) = 0; + virtual void EndMissingDependentProperties(const SValue& sourceName) = 0; + virtual void AddDependencySchemaError(const SValue& souceName, ISchemaValidator* subvalidator) = 0; + virtual bool EndDependencyErrors() = 0; + + virtual void DisallowedValue(const ValidateErrorCode code) = 0; + virtual void StartDisallowedType() = 0; + virtual void AddExpectedType(const typename SchemaType::ValueType& expectedType) = 0; + virtual void EndDisallowedType(const typename SchemaType::ValueType& actualType) = 0; + virtual void NotAllOf(ISchemaValidator** subvalidators, SizeType count) = 0; + virtual void NoneOf(ISchemaValidator** subvalidators, SizeType count) = 0; + virtual void NotOneOf(ISchemaValidator** subvalidators, SizeType count) = 0; + virtual void MultipleOneOf(SizeType index1, SizeType index2) = 0; + virtual void Disallowed() = 0; + virtual void DisallowedWhenWriting() = 0; + virtual void DisallowedWhenReading() = 0; +}; + + +/////////////////////////////////////////////////////////////////////////////// +// Hasher + +// For comparison of compound value +template +class Hasher { +public: + typedef typename Encoding::Ch Ch; + + Hasher(Allocator* allocator = 0, size_t stackCapacity = kDefaultSize) : stack_(allocator, stackCapacity) {} + + bool Null() { return WriteType(kNullType); } + bool Bool(bool b) { return WriteType(b ? kTrueType : kFalseType); } + bool Int(int i) { Number n; n.u.i = i; n.d = static_cast(i); return WriteNumber(n); } + bool Uint(unsigned u) { Number n; n.u.u = u; n.d = static_cast(u); return WriteNumber(n); } + bool Int64(int64_t i) { Number n; n.u.i = i; n.d = static_cast(i); return WriteNumber(n); } + bool Uint64(uint64_t u) { Number n; n.u.u = u; n.d = static_cast(u); return WriteNumber(n); } + bool Double(double d) { + Number n; + if (d < 0) n.u.i = static_cast(d); + else n.u.u = static_cast(d); + n.d = d; + return WriteNumber(n); + } + + bool RawNumber(const Ch* str, SizeType len, bool) { + WriteBuffer(kNumberType, str, len * sizeof(Ch)); + return true; + } + + bool String(const Ch* str, SizeType len, bool) { + WriteBuffer(kStringType, str, len * sizeof(Ch)); + return true; + } + + bool StartObject() { return true; } + bool Key(const Ch* str, SizeType len, bool copy) { return String(str, len, copy); } + bool EndObject(SizeType memberCount) { + uint64_t h = Hash(0, kObjectType); + uint64_t* kv = stack_.template Pop(memberCount * 2); + for (SizeType i = 0; i < memberCount; i++) + // Issue #2205 + // Hasing the key to avoid key=value cases with bug-prone zero-value hash + h ^= Hash(Hash(0, kv[i * 2]), kv[i * 2 + 1]); // Use xor to achieve member order insensitive + *stack_.template Push() = h; + return true; + } + + bool StartArray() { return true; } + bool EndArray(SizeType elementCount) { + uint64_t h = Hash(0, kArrayType); + uint64_t* e = stack_.template Pop(elementCount); + for (SizeType i = 0; i < elementCount; i++) + h = Hash(h, e[i]); // Use hash to achieve element order sensitive + *stack_.template Push() = h; + return true; + } + + bool IsValid() const { return stack_.GetSize() == sizeof(uint64_t); } + + uint64_t GetHashCode() const { + RAPIDJSON_ASSERT(IsValid()); + return *stack_.template Top(); + } + +private: + static const size_t kDefaultSize = 256; + struct Number { + union U { + uint64_t u; + int64_t i; + }u; + double d; + }; + + bool WriteType(Type type) { return WriteBuffer(type, 0, 0); } + + bool WriteNumber(const Number& n) { return WriteBuffer(kNumberType, &n, sizeof(n)); } + + bool WriteBuffer(Type type, const void* data, size_t len) { + // FNV-1a from http://isthe.com/chongo/tech/comp/fnv/ + uint64_t h = Hash(RAPIDJSON_UINT64_C2(0xcbf29ce4, 0x84222325), type); + const unsigned char* d = static_cast(data); + for (size_t i = 0; i < len; i++) + h = Hash(h, d[i]); + *stack_.template Push() = h; + return true; + } + + static uint64_t Hash(uint64_t h, uint64_t d) { + static const uint64_t kPrime = RAPIDJSON_UINT64_C2(0x00000100, 0x000001b3); + h ^= d; + h *= kPrime; + return h; + } + + Stack stack_; +}; + +/////////////////////////////////////////////////////////////////////////////// +// SchemaValidationContext + +template +struct SchemaValidationContext { + typedef Schema SchemaType; + typedef ISchemaStateFactory SchemaValidatorFactoryType; + typedef IValidationErrorHandler ErrorHandlerType; + typedef typename SchemaType::ValueType ValueType; + typedef typename ValueType::Ch Ch; + + enum PatternValidatorType { + kPatternValidatorOnly, + kPatternValidatorWithProperty, + kPatternValidatorWithAdditionalProperty + }; + + SchemaValidationContext(SchemaValidatorFactoryType& f, ErrorHandlerType& eh, const SchemaType* s, unsigned fl = 0) : + factory(f), + error_handler(eh), + schema(s), + flags(fl), + valueSchema(), + invalidKeyword(), + invalidCode(), + hasher(), + arrayElementHashCodes(), + validators(), + validatorCount(), + patternPropertiesValidators(), + patternPropertiesValidatorCount(), + patternPropertiesSchemas(), + patternPropertiesSchemaCount(), + valuePatternValidatorType(kPatternValidatorOnly), + propertyExist(), + inArray(false), + valueUniqueness(false), + arrayUniqueness(false) + { + } + + ~SchemaValidationContext() { + if (hasher) + factory.DestroryHasher(hasher); + if (validators) { + for (SizeType i = 0; i < validatorCount; i++) { + if (validators[i]) { + factory.DestroySchemaValidator(validators[i]); + } + } + factory.FreeState(validators); + } + if (patternPropertiesValidators) { + for (SizeType i = 0; i < patternPropertiesValidatorCount; i++) { + if (patternPropertiesValidators[i]) { + factory.DestroySchemaValidator(patternPropertiesValidators[i]); + } + } + factory.FreeState(patternPropertiesValidators); + } + if (patternPropertiesSchemas) + factory.FreeState(patternPropertiesSchemas); + if (propertyExist) + factory.FreeState(propertyExist); + } + + SchemaValidatorFactoryType& factory; + ErrorHandlerType& error_handler; + const SchemaType* schema; + unsigned flags; + const SchemaType* valueSchema; + const Ch* invalidKeyword; + ValidateErrorCode invalidCode; + void* hasher; // Only validator access + void* arrayElementHashCodes; // Only validator access this + ISchemaValidator** validators; + SizeType validatorCount; + ISchemaValidator** patternPropertiesValidators; + SizeType patternPropertiesValidatorCount; + const SchemaType** patternPropertiesSchemas; + SizeType patternPropertiesSchemaCount; + PatternValidatorType valuePatternValidatorType; + PatternValidatorType objectPatternValidatorType; + SizeType arrayElementIndex; + bool* propertyExist; + bool inArray; + bool valueUniqueness; + bool arrayUniqueness; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Schema + +template +class Schema { +public: + typedef typename SchemaDocumentType::ValueType ValueType; + typedef typename SchemaDocumentType::AllocatorType AllocatorType; + typedef typename SchemaDocumentType::PointerType PointerType; + typedef typename ValueType::EncodingType EncodingType; + typedef typename EncodingType::Ch Ch; + typedef SchemaValidationContext Context; + typedef Schema SchemaType; + typedef GenericValue SValue; + typedef IValidationErrorHandler ErrorHandler; + typedef GenericUri UriType; + friend class GenericSchemaDocument; + + Schema(SchemaDocumentType* schemaDocument, const PointerType& p, const ValueType& value, const ValueType& document, AllocatorType* allocator, const UriType& id = UriType()) : + allocator_(allocator), + uri_(schemaDocument->GetURI(), *allocator), + id_(id, allocator), + spec_(schemaDocument->GetSpecification()), + pointer_(p, allocator), + typeless_(schemaDocument->GetTypeless()), + enum_(), + enumCount_(), + not_(), + type_((1 << kTotalSchemaType) - 1), // typeless + validatorCount_(), + notValidatorIndex_(), + properties_(), + additionalPropertiesSchema_(), + patternProperties_(), + patternPropertyCount_(), + propertyCount_(), + minProperties_(), + maxProperties_(SizeType(~0)), + additionalProperties_(true), + hasDependencies_(), + hasRequired_(), + hasSchemaDependencies_(), + additionalItemsSchema_(), + itemsList_(), + itemsTuple_(), + itemsTupleCount_(), + minItems_(), + maxItems_(SizeType(~0)), + additionalItems_(true), + uniqueItems_(false), + pattern_(), + minLength_(0), + maxLength_(~SizeType(0)), + exclusiveMinimum_(false), + exclusiveMaximum_(false), + defaultValueLength_(0), + readOnly_(false), + writeOnly_(false), + nullable_(false) + { + GenericStringBuffer sb; + p.StringifyUriFragment(sb); + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Schema", sb.GetString(), id.GetString()); + + typedef typename ValueType::ConstValueIterator ConstValueIterator; + typedef typename ValueType::ConstMemberIterator ConstMemberIterator; + + // PR #1393 + // Early add this Schema and its $ref(s) in schemaDocument's map to avoid infinite + // recursion (with recursive schemas), since schemaDocument->getSchema() is always + // checked before creating a new one. Don't cache typeless_, though. + if (this != typeless_) { + typedef typename SchemaDocumentType::SchemaEntry SchemaEntry; + SchemaEntry *entry = schemaDocument->schemaMap_.template Push(); + new (entry) SchemaEntry(pointer_, this, true, allocator_); + schemaDocument->AddSchemaRefs(this); + } + + if (!value.IsObject()) + return; + + // If we have an id property, resolve it with the in-scope id + // Not supported for open api 2.0 or 3.0 + if (spec_.oapi != kVersion20 && spec_.oapi != kVersion30) + if (const ValueType* v = GetMember(value, GetIdString())) { + if (v->IsString()) { + UriType local(*v, allocator); + id_ = local.Resolve(id_, allocator); + RAPIDJSON_SCHEMA_PRINT(SchemaIds, id.GetString(), v->GetString(), id_.GetString()); + } + } + + if (const ValueType* v = GetMember(value, GetTypeString())) { + type_ = 0; + if (v->IsString()) + AddType(*v); + else if (v->IsArray()) + for (ConstValueIterator itr = v->Begin(); itr != v->End(); ++itr) + AddType(*itr); + } + + if (const ValueType* v = GetMember(value, GetEnumString())) { + if (v->IsArray() && v->Size() > 0) { + enum_ = static_cast(allocator_->Malloc(sizeof(uint64_t) * v->Size())); + for (ConstValueIterator itr = v->Begin(); itr != v->End(); ++itr) { + typedef Hasher > EnumHasherType; + char buffer[256u + 24]; + MemoryPoolAllocator hasherAllocator(buffer, sizeof(buffer)); + EnumHasherType h(&hasherAllocator, 256); + itr->Accept(h); + enum_[enumCount_++] = h.GetHashCode(); + } + } + } + + if (schemaDocument) + AssignIfExist(allOf_, *schemaDocument, p, value, GetAllOfString(), document); + + // AnyOf, OneOf, Not not supported for open api 2.0 + if (schemaDocument && spec_.oapi != kVersion20) { + AssignIfExist(anyOf_, *schemaDocument, p, value, GetAnyOfString(), document); + AssignIfExist(oneOf_, *schemaDocument, p, value, GetOneOfString(), document); + + if (const ValueType* v = GetMember(value, GetNotString())) { + schemaDocument->CreateSchema(¬_, p.Append(GetNotString(), allocator_), *v, document, id_); + notValidatorIndex_ = validatorCount_; + validatorCount_++; + } + } + + // Object + + const ValueType* properties = GetMember(value, GetPropertiesString()); + const ValueType* required = GetMember(value, GetRequiredString()); + const ValueType* dependencies = GetMember(value, GetDependenciesString()); + { + // Gather properties from properties/required/dependencies + SValue allProperties(kArrayType); + + if (properties && properties->IsObject()) + for (ConstMemberIterator itr = properties->MemberBegin(); itr != properties->MemberEnd(); ++itr) + AddUniqueElement(allProperties, itr->name); + + if (required && required->IsArray()) + for (ConstValueIterator itr = required->Begin(); itr != required->End(); ++itr) + if (itr->IsString()) + AddUniqueElement(allProperties, *itr); + + // Dependencies not supported for open api 2.0 and 3.0 + if (spec_.oapi != kVersion20 && spec_.oapi != kVersion30) + if (dependencies && dependencies->IsObject()) + for (ConstMemberIterator itr = dependencies->MemberBegin(); itr != dependencies->MemberEnd(); ++itr) { + AddUniqueElement(allProperties, itr->name); + if (itr->value.IsArray()) + for (ConstValueIterator i = itr->value.Begin(); i != itr->value.End(); ++i) + if (i->IsString()) + AddUniqueElement(allProperties, *i); + } + + if (allProperties.Size() > 0) { + propertyCount_ = allProperties.Size(); + properties_ = static_cast(allocator_->Malloc(sizeof(Property) * propertyCount_)); + for (SizeType i = 0; i < propertyCount_; i++) { + new (&properties_[i]) Property(); + properties_[i].name = allProperties[i]; + properties_[i].schema = typeless_; + } + } + } + + if (properties && properties->IsObject()) { + PointerType q = p.Append(GetPropertiesString(), allocator_); + for (ConstMemberIterator itr = properties->MemberBegin(); itr != properties->MemberEnd(); ++itr) { + SizeType index; + if (FindPropertyIndex(itr->name, &index)) + schemaDocument->CreateSchema(&properties_[index].schema, q.Append(itr->name, allocator_), itr->value, document, id_); + } + } + + // PatternProperties not supported for open api 2.0 and 3.0 + if (spec_.oapi != kVersion20 && spec_.oapi != kVersion30) + if (const ValueType* v = GetMember(value, GetPatternPropertiesString())) { + PointerType q = p.Append(GetPatternPropertiesString(), allocator_); + patternProperties_ = static_cast(allocator_->Malloc(sizeof(PatternProperty) * v->MemberCount())); + patternPropertyCount_ = 0; + + for (ConstMemberIterator itr = v->MemberBegin(); itr != v->MemberEnd(); ++itr) { + new (&patternProperties_[patternPropertyCount_]) PatternProperty(); + PointerType r = q.Append(itr->name, allocator_); + patternProperties_[patternPropertyCount_].pattern = CreatePattern(itr->name, schemaDocument, r); + schemaDocument->CreateSchema(&patternProperties_[patternPropertyCount_].schema, r, itr->value, document, id_); + patternPropertyCount_++; + } + } + + if (required && required->IsArray()) + for (ConstValueIterator itr = required->Begin(); itr != required->End(); ++itr) + if (itr->IsString()) { + SizeType index; + if (FindPropertyIndex(*itr, &index)) { + properties_[index].required = true; + hasRequired_ = true; + } + } + + // Dependencies not supported for open api 2.0 and 3.0 + if (spec_.oapi != kVersion20 && spec_.oapi != kVersion30) + if (dependencies && dependencies->IsObject()) { + PointerType q = p.Append(GetDependenciesString(), allocator_); + hasDependencies_ = true; + for (ConstMemberIterator itr = dependencies->MemberBegin(); itr != dependencies->MemberEnd(); ++itr) { + SizeType sourceIndex; + if (FindPropertyIndex(itr->name, &sourceIndex)) { + if (itr->value.IsArray()) { + properties_[sourceIndex].dependencies = static_cast(allocator_->Malloc(sizeof(bool) * propertyCount_)); + std::memset(properties_[sourceIndex].dependencies, 0, sizeof(bool)* propertyCount_); + for (ConstValueIterator targetItr = itr->value.Begin(); targetItr != itr->value.End(); ++targetItr) { + SizeType targetIndex; + if (FindPropertyIndex(*targetItr, &targetIndex)) + properties_[sourceIndex].dependencies[targetIndex] = true; + } + } + else if (itr->value.IsObject()) { + hasSchemaDependencies_ = true; + schemaDocument->CreateSchema(&properties_[sourceIndex].dependenciesSchema, q.Append(itr->name, allocator_), itr->value, document, id_); + properties_[sourceIndex].dependenciesValidatorIndex = validatorCount_; + validatorCount_++; + } + } + } + } + + if (const ValueType* v = GetMember(value, GetAdditionalPropertiesString())) { + if (v->IsBool()) + additionalProperties_ = v->GetBool(); + else if (v->IsObject()) + schemaDocument->CreateSchema(&additionalPropertiesSchema_, p.Append(GetAdditionalPropertiesString(), allocator_), *v, document, id_); + } + + AssignIfExist(minProperties_, value, GetMinPropertiesString()); + AssignIfExist(maxProperties_, value, GetMaxPropertiesString()); + + // Array + if (const ValueType* v = GetMember(value, GetItemsString())) { + PointerType q = p.Append(GetItemsString(), allocator_); + if (v->IsObject()) // List validation + schemaDocument->CreateSchema(&itemsList_, q, *v, document, id_); + else if (v->IsArray()) { // Tuple validation + itemsTuple_ = static_cast(allocator_->Malloc(sizeof(const Schema*) * v->Size())); + SizeType index = 0; + for (ConstValueIterator itr = v->Begin(); itr != v->End(); ++itr, index++) + schemaDocument->CreateSchema(&itemsTuple_[itemsTupleCount_++], q.Append(index, allocator_), *itr, document, id_); + } + } + + AssignIfExist(minItems_, value, GetMinItemsString()); + AssignIfExist(maxItems_, value, GetMaxItemsString()); + + // AdditionalItems not supported for openapi 2.0 and 3.0 + if (spec_.oapi != kVersion20 && spec_.oapi != kVersion30) + if (const ValueType* v = GetMember(value, GetAdditionalItemsString())) { + if (v->IsBool()) + additionalItems_ = v->GetBool(); + else if (v->IsObject()) + schemaDocument->CreateSchema(&additionalItemsSchema_, p.Append(GetAdditionalItemsString(), allocator_), *v, document, id_); + } + + AssignIfExist(uniqueItems_, value, GetUniqueItemsString()); + + // String + AssignIfExist(minLength_, value, GetMinLengthString()); + AssignIfExist(maxLength_, value, GetMaxLengthString()); + + if (const ValueType* v = GetMember(value, GetPatternString())) + pattern_ = CreatePattern(*v, schemaDocument, p.Append(GetPatternString(), allocator_)); + + // Number + if (const ValueType* v = GetMember(value, GetMinimumString())) + if (v->IsNumber()) + minimum_.CopyFrom(*v, *allocator_); + + if (const ValueType* v = GetMember(value, GetMaximumString())) + if (v->IsNumber()) + maximum_.CopyFrom(*v, *allocator_); + + AssignIfExist(exclusiveMinimum_, value, GetExclusiveMinimumString()); + AssignIfExist(exclusiveMaximum_, value, GetExclusiveMaximumString()); + + if (const ValueType* v = GetMember(value, GetMultipleOfString())) + if (v->IsNumber() && v->GetDouble() > 0.0) + multipleOf_.CopyFrom(*v, *allocator_); + + // Default + if (const ValueType* v = GetMember(value, GetDefaultValueString())) + if (v->IsString()) + defaultValueLength_ = v->GetStringLength(); + + // ReadOnly - open api only (until draft 7 supported) + // WriteOnly - open api 3 only (until draft 7 supported) + // Both can't be true + if (spec_.oapi != kVersionNone) + AssignIfExist(readOnly_, value, GetReadOnlyString()); + if (spec_.oapi >= kVersion30) + AssignIfExist(writeOnly_, value, GetWriteOnlyString()); + if (readOnly_ && writeOnly_) + schemaDocument->SchemaError(kSchemaErrorReadOnlyAndWriteOnly, p); + + // Nullable - open api 3 only + // If true add 'null' as allowable type + if (spec_.oapi >= kVersion30) { + AssignIfExist(nullable_, value, GetNullableString()); + if (nullable_) + AddType(GetNullString()); + } + } + + ~Schema() { + AllocatorType::Free(enum_); + if (properties_) { + for (SizeType i = 0; i < propertyCount_; i++) + properties_[i].~Property(); + AllocatorType::Free(properties_); + } + if (patternProperties_) { + for (SizeType i = 0; i < patternPropertyCount_; i++) + patternProperties_[i].~PatternProperty(); + AllocatorType::Free(patternProperties_); + } + AllocatorType::Free(itemsTuple_); +#if RAPIDJSON_SCHEMA_HAS_REGEX + if (pattern_) { + pattern_->~RegexType(); + AllocatorType::Free(pattern_); + } +#endif + } + + const SValue& GetURI() const { + return uri_; + } + + const UriType& GetId() const { + return id_; + } + + const Specification& GetSpecification() const { + return spec_; + } + + const PointerType& GetPointer() const { + return pointer_; + } + + bool BeginValue(Context& context) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::BeginValue"); + if (context.inArray) { + if (uniqueItems_) + context.valueUniqueness = true; + + if (itemsList_) + context.valueSchema = itemsList_; + else if (itemsTuple_) { + if (context.arrayElementIndex < itemsTupleCount_) + context.valueSchema = itemsTuple_[context.arrayElementIndex]; + else if (additionalItemsSchema_) + context.valueSchema = additionalItemsSchema_; + else if (additionalItems_) + context.valueSchema = typeless_; + else { + context.error_handler.DisallowedItem(context.arrayElementIndex); + // Must set valueSchema for when kValidateContinueOnErrorFlag is set, else reports spurious type error + context.valueSchema = typeless_; + // Must bump arrayElementIndex for when kValidateContinueOnErrorFlag is set + context.arrayElementIndex++; + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorAdditionalItems); + } + } + else + context.valueSchema = typeless_; + + context.arrayElementIndex++; + } + return true; + } + + RAPIDJSON_FORCEINLINE bool EndValue(Context& context) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::EndValue"); + // Only check pattern properties if we have validators + if (context.patternPropertiesValidatorCount > 0) { + bool otherValid = false; + SizeType count = context.patternPropertiesValidatorCount; + if (context.objectPatternValidatorType != Context::kPatternValidatorOnly) + otherValid = context.patternPropertiesValidators[--count]->IsValid(); + + bool patternValid = true; + for (SizeType i = 0; i < count; i++) + if (!context.patternPropertiesValidators[i]->IsValid()) { + patternValid = false; + break; + } + + if (context.objectPatternValidatorType == Context::kPatternValidatorOnly) { + if (!patternValid) { + context.error_handler.PropertyViolations(context.patternPropertiesValidators, count); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorPatternProperties); + } + } + else if (context.objectPatternValidatorType == Context::kPatternValidatorWithProperty) { + if (!patternValid || !otherValid) { + context.error_handler.PropertyViolations(context.patternPropertiesValidators, count + 1); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorPatternProperties); + } + } + else if (!patternValid && !otherValid) { // kPatternValidatorWithAdditionalProperty) + context.error_handler.PropertyViolations(context.patternPropertiesValidators, count + 1); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorPatternProperties); + } + } + + // For enums only check if we have a hasher + if (enum_ && context.hasher) { + const uint64_t h = context.factory.GetHashCode(context.hasher); + for (SizeType i = 0; i < enumCount_; i++) + if (enum_[i] == h) + goto foundEnum; + context.error_handler.DisallowedValue(kValidateErrorEnum); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorEnum); + foundEnum:; + } + + // Only check allOf etc if we have validators + if (context.validatorCount > 0) { + if (allOf_.schemas) + for (SizeType i = allOf_.begin; i < allOf_.begin + allOf_.count; i++) + if (!context.validators[i]->IsValid()) { + context.error_handler.NotAllOf(&context.validators[allOf_.begin], allOf_.count); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorAllOf); + } + + if (anyOf_.schemas) { + for (SizeType i = anyOf_.begin; i < anyOf_.begin + anyOf_.count; i++) + if (context.validators[i]->IsValid()) + goto foundAny; + context.error_handler.NoneOf(&context.validators[anyOf_.begin], anyOf_.count); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorAnyOf); + foundAny:; + } + + if (oneOf_.schemas) { + bool oneValid = false; + SizeType firstMatch = 0; + for (SizeType i = oneOf_.begin; i < oneOf_.begin + oneOf_.count; i++) + if (context.validators[i]->IsValid()) { + if (oneValid) { + context.error_handler.MultipleOneOf(firstMatch, i - oneOf_.begin); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorOneOfMatch); + } else { + oneValid = true; + firstMatch = i - oneOf_.begin; + } + } + if (!oneValid) { + context.error_handler.NotOneOf(&context.validators[oneOf_.begin], oneOf_.count); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorOneOf); + } + } + + if (not_ && context.validators[notValidatorIndex_]->IsValid()) { + context.error_handler.Disallowed(); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorNot); + } + } + + return true; + } + + bool Null(Context& context) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Null"); + if (!(type_ & (1 << kNullSchemaType))) { + DisallowedType(context, GetNullString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + return CreateParallelValidator(context); + } + + bool Bool(Context& context, bool b) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Bool", b); + if (!CheckBool(context, b)) + return false; + return CreateParallelValidator(context); + } + + bool Int(Context& context, int i) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Int", (int64_t)i); + if (!CheckInt(context, i)) + return false; + return CreateParallelValidator(context); + } + + bool Uint(Context& context, unsigned u) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Uint", (uint64_t)u); + if (!CheckUint(context, u)) + return false; + return CreateParallelValidator(context); + } + + bool Int64(Context& context, int64_t i) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Int64", i); + if (!CheckInt(context, i)) + return false; + return CreateParallelValidator(context); + } + + bool Uint64(Context& context, uint64_t u) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Uint64", u); + if (!CheckUint(context, u)) + return false; + return CreateParallelValidator(context); + } + + bool Double(Context& context, double d) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Double", d); + if (!(type_ & (1 << kNumberSchemaType))) { + DisallowedType(context, GetNumberString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + if (!minimum_.IsNull() && !CheckDoubleMinimum(context, d)) + return false; + + if (!maximum_.IsNull() && !CheckDoubleMaximum(context, d)) + return false; + + if (!multipleOf_.IsNull() && !CheckDoubleMultipleOf(context, d)) + return false; + + return CreateParallelValidator(context); + } + + bool String(Context& context, const Ch* str, SizeType length, bool) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::String", str); + if (!(type_ & (1 << kStringSchemaType))) { + DisallowedType(context, GetStringString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + if (minLength_ != 0 || maxLength_ != SizeType(~0)) { + SizeType count; + if (internal::CountStringCodePoint(str, length, &count)) { + if (count < minLength_) { + context.error_handler.TooShort(str, length, minLength_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMinLength); + } + if (count > maxLength_) { + context.error_handler.TooLong(str, length, maxLength_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMaxLength); + } + } + } + + if (pattern_ && !IsPatternMatch(pattern_, str, length)) { + context.error_handler.DoesNotMatch(str, length); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorPattern); + } + + return CreateParallelValidator(context); + } + + bool StartObject(Context& context) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::StartObject"); + if (!(type_ & (1 << kObjectSchemaType))) { + DisallowedType(context, GetObjectString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + if (hasDependencies_ || hasRequired_) { + context.propertyExist = static_cast(context.factory.MallocState(sizeof(bool) * propertyCount_)); + std::memset(context.propertyExist, 0, sizeof(bool) * propertyCount_); + } + + if (patternProperties_) { // pre-allocate schema array + SizeType count = patternPropertyCount_ + 1; // extra for valuePatternValidatorType + context.patternPropertiesSchemas = static_cast(context.factory.MallocState(sizeof(const SchemaType*) * count)); + context.patternPropertiesSchemaCount = 0; + std::memset(context.patternPropertiesSchemas, 0, sizeof(SchemaType*) * count); + } + + return CreateParallelValidator(context); + } + + bool Key(Context& context, const Ch* str, SizeType len, bool) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::Key", str); + + if (patternProperties_) { + context.patternPropertiesSchemaCount = 0; + for (SizeType i = 0; i < patternPropertyCount_; i++) + if (patternProperties_[i].pattern && IsPatternMatch(patternProperties_[i].pattern, str, len)) { + context.patternPropertiesSchemas[context.patternPropertiesSchemaCount++] = patternProperties_[i].schema; + context.valueSchema = typeless_; + } + } + + SizeType index = 0; + if (FindPropertyIndex(ValueType(str, len).Move(), &index)) { + if (context.patternPropertiesSchemaCount > 0) { + context.patternPropertiesSchemas[context.patternPropertiesSchemaCount++] = properties_[index].schema; + context.valueSchema = typeless_; + context.valuePatternValidatorType = Context::kPatternValidatorWithProperty; + } + else + context.valueSchema = properties_[index].schema; + + if (context.propertyExist) + context.propertyExist[index] = true; + + return true; + } + + if (additionalPropertiesSchema_) { + if (context.patternPropertiesSchemaCount > 0) { + context.patternPropertiesSchemas[context.patternPropertiesSchemaCount++] = additionalPropertiesSchema_; + context.valueSchema = typeless_; + context.valuePatternValidatorType = Context::kPatternValidatorWithAdditionalProperty; + } + else + context.valueSchema = additionalPropertiesSchema_; + return true; + } + else if (additionalProperties_) { + context.valueSchema = typeless_; + return true; + } + + if (context.patternPropertiesSchemaCount == 0) { // patternProperties are not additional properties + // Must set valueSchema for when kValidateContinueOnErrorFlag is set, else reports spurious type error + context.valueSchema = typeless_; + context.error_handler.DisallowedProperty(str, len); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorAdditionalProperties); + } + + return true; + } + + bool EndObject(Context& context, SizeType memberCount) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::EndObject"); + if (hasRequired_) { + context.error_handler.StartMissingProperties(); + for (SizeType index = 0; index < propertyCount_; index++) + if (properties_[index].required && !context.propertyExist[index]) + if (properties_[index].schema->defaultValueLength_ == 0 ) + context.error_handler.AddMissingProperty(properties_[index].name); + if (context.error_handler.EndMissingProperties()) + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorRequired); + } + + if (memberCount < minProperties_) { + context.error_handler.TooFewProperties(memberCount, minProperties_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMinProperties); + } + + if (memberCount > maxProperties_) { + context.error_handler.TooManyProperties(memberCount, maxProperties_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMaxProperties); + } + + if (hasDependencies_) { + context.error_handler.StartDependencyErrors(); + for (SizeType sourceIndex = 0; sourceIndex < propertyCount_; sourceIndex++) { + const Property& source = properties_[sourceIndex]; + if (context.propertyExist[sourceIndex]) { + if (source.dependencies) { + context.error_handler.StartMissingDependentProperties(); + for (SizeType targetIndex = 0; targetIndex < propertyCount_; targetIndex++) + if (source.dependencies[targetIndex] && !context.propertyExist[targetIndex]) + context.error_handler.AddMissingDependentProperty(properties_[targetIndex].name); + context.error_handler.EndMissingDependentProperties(source.name); + } + else if (source.dependenciesSchema) { + ISchemaValidator* dependenciesValidator = context.validators[source.dependenciesValidatorIndex]; + if (!dependenciesValidator->IsValid()) + context.error_handler.AddDependencySchemaError(source.name, dependenciesValidator); + } + } + } + if (context.error_handler.EndDependencyErrors()) + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorDependencies); + } + + return true; + } + + bool StartArray(Context& context) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::StartArray"); + context.arrayElementIndex = 0; + context.inArray = true; // Ensure we note that we are in an array + + if (!(type_ & (1 << kArraySchemaType))) { + DisallowedType(context, GetArrayString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + return CreateParallelValidator(context); + } + + bool EndArray(Context& context, SizeType elementCount) const { + RAPIDJSON_SCHEMA_PRINT(Method, "Schema::EndArray"); + context.inArray = false; + + if (elementCount < minItems_) { + context.error_handler.TooFewItems(elementCount, minItems_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMinItems); + } + + if (elementCount > maxItems_) { + context.error_handler.TooManyItems(elementCount, maxItems_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMaxItems); + } + + return true; + } + + static const ValueType& GetValidateErrorKeyword(ValidateErrorCode validateErrorCode) { + switch (validateErrorCode) { + case kValidateErrorMultipleOf: return GetMultipleOfString(); + case kValidateErrorMaximum: return GetMaximumString(); + case kValidateErrorExclusiveMaximum: return GetMaximumString(); // Same + case kValidateErrorMinimum: return GetMinimumString(); + case kValidateErrorExclusiveMinimum: return GetMinimumString(); // Same + + case kValidateErrorMaxLength: return GetMaxLengthString(); + case kValidateErrorMinLength: return GetMinLengthString(); + case kValidateErrorPattern: return GetPatternString(); + + case kValidateErrorMaxItems: return GetMaxItemsString(); + case kValidateErrorMinItems: return GetMinItemsString(); + case kValidateErrorUniqueItems: return GetUniqueItemsString(); + case kValidateErrorAdditionalItems: return GetAdditionalItemsString(); + + case kValidateErrorMaxProperties: return GetMaxPropertiesString(); + case kValidateErrorMinProperties: return GetMinPropertiesString(); + case kValidateErrorRequired: return GetRequiredString(); + case kValidateErrorAdditionalProperties: return GetAdditionalPropertiesString(); + case kValidateErrorPatternProperties: return GetPatternPropertiesString(); + case kValidateErrorDependencies: return GetDependenciesString(); + + case kValidateErrorEnum: return GetEnumString(); + case kValidateErrorType: return GetTypeString(); + + case kValidateErrorOneOf: return GetOneOfString(); + case kValidateErrorOneOfMatch: return GetOneOfString(); // Same + case kValidateErrorAllOf: return GetAllOfString(); + case kValidateErrorAnyOf: return GetAnyOfString(); + case kValidateErrorNot: return GetNotString(); + + case kValidateErrorReadOnly: return GetReadOnlyString(); + case kValidateErrorWriteOnly: return GetWriteOnlyString(); + + default: return GetNullString(); + } + } + + + // Generate functions for string literal according to Ch +#define RAPIDJSON_STRING_(name, ...) \ + static const ValueType& Get##name##String() {\ + static const Ch s[] = { __VA_ARGS__, '\0' };\ + static const ValueType v(s, static_cast(sizeof(s) / sizeof(Ch) - 1));\ + return v;\ + } + + RAPIDJSON_STRING_(Null, 'n', 'u', 'l', 'l') + RAPIDJSON_STRING_(Boolean, 'b', 'o', 'o', 'l', 'e', 'a', 'n') + RAPIDJSON_STRING_(Object, 'o', 'b', 'j', 'e', 'c', 't') + RAPIDJSON_STRING_(Array, 'a', 'r', 'r', 'a', 'y') + RAPIDJSON_STRING_(String, 's', 't', 'r', 'i', 'n', 'g') + RAPIDJSON_STRING_(Number, 'n', 'u', 'm', 'b', 'e', 'r') + RAPIDJSON_STRING_(Integer, 'i', 'n', 't', 'e', 'g', 'e', 'r') + RAPIDJSON_STRING_(Type, 't', 'y', 'p', 'e') + RAPIDJSON_STRING_(Enum, 'e', 'n', 'u', 'm') + RAPIDJSON_STRING_(AllOf, 'a', 'l', 'l', 'O', 'f') + RAPIDJSON_STRING_(AnyOf, 'a', 'n', 'y', 'O', 'f') + RAPIDJSON_STRING_(OneOf, 'o', 'n', 'e', 'O', 'f') + RAPIDJSON_STRING_(Not, 'n', 'o', 't') + RAPIDJSON_STRING_(Properties, 'p', 'r', 'o', 'p', 'e', 'r', 't', 'i', 'e', 's') + RAPIDJSON_STRING_(Required, 'r', 'e', 'q', 'u', 'i', 'r', 'e', 'd') + RAPIDJSON_STRING_(Dependencies, 'd', 'e', 'p', 'e', 'n', 'd', 'e', 'n', 'c', 'i', 'e', 's') + RAPIDJSON_STRING_(PatternProperties, 'p', 'a', 't', 't', 'e', 'r', 'n', 'P', 'r', 'o', 'p', 'e', 'r', 't', 'i', 'e', 's') + RAPIDJSON_STRING_(AdditionalProperties, 'a', 'd', 'd', 'i', 't', 'i', 'o', 'n', 'a', 'l', 'P', 'r', 'o', 'p', 'e', 'r', 't', 'i', 'e', 's') + RAPIDJSON_STRING_(MinProperties, 'm', 'i', 'n', 'P', 'r', 'o', 'p', 'e', 'r', 't', 'i', 'e', 's') + RAPIDJSON_STRING_(MaxProperties, 'm', 'a', 'x', 'P', 'r', 'o', 'p', 'e', 'r', 't', 'i', 'e', 's') + RAPIDJSON_STRING_(Items, 'i', 't', 'e', 'm', 's') + RAPIDJSON_STRING_(MinItems, 'm', 'i', 'n', 'I', 't', 'e', 'm', 's') + RAPIDJSON_STRING_(MaxItems, 'm', 'a', 'x', 'I', 't', 'e', 'm', 's') + RAPIDJSON_STRING_(AdditionalItems, 'a', 'd', 'd', 'i', 't', 'i', 'o', 'n', 'a', 'l', 'I', 't', 'e', 'm', 's') + RAPIDJSON_STRING_(UniqueItems, 'u', 'n', 'i', 'q', 'u', 'e', 'I', 't', 'e', 'm', 's') + RAPIDJSON_STRING_(MinLength, 'm', 'i', 'n', 'L', 'e', 'n', 'g', 't', 'h') + RAPIDJSON_STRING_(MaxLength, 'm', 'a', 'x', 'L', 'e', 'n', 'g', 't', 'h') + RAPIDJSON_STRING_(Pattern, 'p', 'a', 't', 't', 'e', 'r', 'n') + RAPIDJSON_STRING_(Minimum, 'm', 'i', 'n', 'i', 'm', 'u', 'm') + RAPIDJSON_STRING_(Maximum, 'm', 'a', 'x', 'i', 'm', 'u', 'm') + RAPIDJSON_STRING_(ExclusiveMinimum, 'e', 'x', 'c', 'l', 'u', 's', 'i', 'v', 'e', 'M', 'i', 'n', 'i', 'm', 'u', 'm') + RAPIDJSON_STRING_(ExclusiveMaximum, 'e', 'x', 'c', 'l', 'u', 's', 'i', 'v', 'e', 'M', 'a', 'x', 'i', 'm', 'u', 'm') + RAPIDJSON_STRING_(MultipleOf, 'm', 'u', 'l', 't', 'i', 'p', 'l', 'e', 'O', 'f') + RAPIDJSON_STRING_(DefaultValue, 'd', 'e', 'f', 'a', 'u', 'l', 't') + RAPIDJSON_STRING_(Schema, '$', 's', 'c', 'h', 'e', 'm', 'a') + RAPIDJSON_STRING_(Ref, '$', 'r', 'e', 'f') + RAPIDJSON_STRING_(Id, 'i', 'd') + RAPIDJSON_STRING_(Swagger, 's', 'w', 'a', 'g', 'g', 'e', 'r') + RAPIDJSON_STRING_(OpenApi, 'o', 'p', 'e', 'n', 'a', 'p', 'i') + RAPIDJSON_STRING_(ReadOnly, 'r', 'e', 'a', 'd', 'O', 'n', 'l', 'y') + RAPIDJSON_STRING_(WriteOnly, 'w', 'r', 'i', 't', 'e', 'O', 'n', 'l', 'y') + RAPIDJSON_STRING_(Nullable, 'n', 'u', 'l', 'l', 'a', 'b', 'l', 'e') + +#undef RAPIDJSON_STRING_ + +private: + enum SchemaValueType { + kNullSchemaType, + kBooleanSchemaType, + kObjectSchemaType, + kArraySchemaType, + kStringSchemaType, + kNumberSchemaType, + kIntegerSchemaType, + kTotalSchemaType + }; + +#if RAPIDJSON_SCHEMA_USE_INTERNALREGEX + typedef internal::GenericRegex RegexType; +#elif RAPIDJSON_SCHEMA_USE_STDREGEX + typedef std::basic_regex RegexType; +#else + typedef char RegexType; +#endif + + struct SchemaArray { + SchemaArray() : schemas(), count() {} + ~SchemaArray() { AllocatorType::Free(schemas); } + const SchemaType** schemas; + SizeType begin; // begin index of context.validators + SizeType count; + }; + + template + void AddUniqueElement(V1& a, const V2& v) { + for (typename V1::ConstValueIterator itr = a.Begin(); itr != a.End(); ++itr) + if (*itr == v) + return; + V1 c(v, *allocator_); + a.PushBack(c, *allocator_); + } + + static const ValueType* GetMember(const ValueType& value, const ValueType& name) { + typename ValueType::ConstMemberIterator itr = value.FindMember(name); + return itr != value.MemberEnd() ? &(itr->value) : 0; + } + + static void AssignIfExist(bool& out, const ValueType& value, const ValueType& name) { + if (const ValueType* v = GetMember(value, name)) + if (v->IsBool()) + out = v->GetBool(); + } + + static void AssignIfExist(SizeType& out, const ValueType& value, const ValueType& name) { + if (const ValueType* v = GetMember(value, name)) + if (v->IsUint64() && v->GetUint64() <= SizeType(~0)) + out = static_cast(v->GetUint64()); + } + + void AssignIfExist(SchemaArray& out, SchemaDocumentType& schemaDocument, const PointerType& p, const ValueType& value, const ValueType& name, const ValueType& document) { + if (const ValueType* v = GetMember(value, name)) { + if (v->IsArray() && v->Size() > 0) { + PointerType q = p.Append(name, allocator_); + out.count = v->Size(); + out.schemas = static_cast(allocator_->Malloc(out.count * sizeof(const Schema*))); + memset(out.schemas, 0, sizeof(Schema*)* out.count); + for (SizeType i = 0; i < out.count; i++) + schemaDocument.CreateSchema(&out.schemas[i], q.Append(i, allocator_), (*v)[i], document, id_); + out.begin = validatorCount_; + validatorCount_ += out.count; + } + } + } + +#if RAPIDJSON_SCHEMA_USE_INTERNALREGEX + template + RegexType* CreatePattern(const ValueType& value, SchemaDocumentType* sd, const PointerType& p) { + if (value.IsString()) { + RegexType* r = new (allocator_->Malloc(sizeof(RegexType))) RegexType(value.GetString(), allocator_); + if (!r->IsValid()) { + sd->SchemaErrorValue(kSchemaErrorRegexInvalid, p, value.GetString(), value.GetStringLength()); + r->~RegexType(); + AllocatorType::Free(r); + r = 0; + } + return r; + } + return 0; + } + + static bool IsPatternMatch(const RegexType* pattern, const Ch *str, SizeType) { + GenericRegexSearch rs(*pattern); + return rs.Search(str); + } +#elif RAPIDJSON_SCHEMA_USE_STDREGEX + template + RegexType* CreatePattern(const ValueType& value, SchemaDocumentType* sd, const PointerType& p) { + if (value.IsString()) { + RegexType *r = static_cast(allocator_->Malloc(sizeof(RegexType))); + try { + return new (r) RegexType(value.GetString(), std::size_t(value.GetStringLength()), std::regex_constants::ECMAScript); + } + catch (const std::regex_error& e) { + sd->SchemaErrorValue(kSchemaErrorRegexInvalid, p, value.GetString(), value.GetStringLength()); + AllocatorType::Free(r); + } + } + return 0; + } + + static bool IsPatternMatch(const RegexType* pattern, const Ch *str, SizeType length) { + std::match_results r; + return std::regex_search(str, str + length, r, *pattern); + } +#else + template + RegexType* CreatePattern(const ValueType&) { + return 0; + } + + static bool IsPatternMatch(const RegexType*, const Ch *, SizeType) { return true; } +#endif // RAPIDJSON_SCHEMA_USE_STDREGEX + + void AddType(const ValueType& type) { + if (type == GetNullString() ) type_ |= 1 << kNullSchemaType; + else if (type == GetBooleanString()) type_ |= 1 << kBooleanSchemaType; + else if (type == GetObjectString() ) type_ |= 1 << kObjectSchemaType; + else if (type == GetArrayString() ) type_ |= 1 << kArraySchemaType; + else if (type == GetStringString() ) type_ |= 1 << kStringSchemaType; + else if (type == GetIntegerString()) type_ |= 1 << kIntegerSchemaType; + else if (type == GetNumberString() ) type_ |= (1 << kNumberSchemaType) | (1 << kIntegerSchemaType); + } + + // Creates parallel validators for allOf, anyOf, oneOf, not and schema dependencies, if required. + // Also creates a hasher for enums and array uniqueness, if required. + // Also a useful place to add type-independent error checks. + bool CreateParallelValidator(Context& context) const { + if (enum_ || context.arrayUniqueness) + context.hasher = context.factory.CreateHasher(); + + if (validatorCount_) { + RAPIDJSON_ASSERT(context.validators == 0); + context.validators = static_cast(context.factory.MallocState(sizeof(ISchemaValidator*) * validatorCount_)); + std::memset(context.validators, 0, sizeof(ISchemaValidator*) * validatorCount_); + context.validatorCount = validatorCount_; + + // Always return after first failure for these sub-validators + if (allOf_.schemas) + CreateSchemaValidators(context, allOf_, false); + + if (anyOf_.schemas) + CreateSchemaValidators(context, anyOf_, false); + + if (oneOf_.schemas) + CreateSchemaValidators(context, oneOf_, false); + + if (not_) + context.validators[notValidatorIndex_] = context.factory.CreateSchemaValidator(*not_, false); + + if (hasSchemaDependencies_) { + for (SizeType i = 0; i < propertyCount_; i++) + if (properties_[i].dependenciesSchema) + context.validators[properties_[i].dependenciesValidatorIndex] = context.factory.CreateSchemaValidator(*properties_[i].dependenciesSchema, false); + } + } + + // Add any other type-independent checks here + if (readOnly_ && (context.flags & kValidateWriteFlag)) { + context.error_handler.DisallowedWhenWriting(); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorReadOnly); + } + if (writeOnly_ && (context.flags & kValidateReadFlag)) { + context.error_handler.DisallowedWhenReading(); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorWriteOnly); + } + + return true; + } + + void CreateSchemaValidators(Context& context, const SchemaArray& schemas, const bool inheritContinueOnErrors) const { + for (SizeType i = 0; i < schemas.count; i++) + context.validators[schemas.begin + i] = context.factory.CreateSchemaValidator(*schemas.schemas[i], inheritContinueOnErrors); + } + + // O(n) + bool FindPropertyIndex(const ValueType& name, SizeType* outIndex) const { + SizeType len = name.GetStringLength(); + const Ch* str = name.GetString(); + for (SizeType index = 0; index < propertyCount_; index++) + if (properties_[index].name.GetStringLength() == len && + (std::memcmp(properties_[index].name.GetString(), str, sizeof(Ch) * len) == 0)) + { + *outIndex = index; + return true; + } + return false; + } + + bool CheckBool(Context& context, bool) const { + if (!(type_ & (1 << kBooleanSchemaType))) { + DisallowedType(context, GetBooleanString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + return true; + } + + bool CheckInt(Context& context, int64_t i) const { + if (!(type_ & ((1 << kIntegerSchemaType) | (1 << kNumberSchemaType)))) { + DisallowedType(context, GetIntegerString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + if (!minimum_.IsNull()) { + if (minimum_.IsInt64()) { + if (exclusiveMinimum_ ? i <= minimum_.GetInt64() : i < minimum_.GetInt64()) { + context.error_handler.BelowMinimum(i, minimum_, exclusiveMinimum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMinimum_ ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum); + } + } + else if (minimum_.IsUint64()) { + context.error_handler.BelowMinimum(i, minimum_, exclusiveMinimum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMinimum_ ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum); // i <= max(int64_t) < minimum.GetUint64() + } + else if (!CheckDoubleMinimum(context, static_cast(i))) + return false; + } + + if (!maximum_.IsNull()) { + if (maximum_.IsInt64()) { + if (exclusiveMaximum_ ? i >= maximum_.GetInt64() : i > maximum_.GetInt64()) { + context.error_handler.AboveMaximum(i, maximum_, exclusiveMaximum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMaximum_ ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum); + } + } + else if (maximum_.IsUint64()) { } + /* do nothing */ // i <= max(int64_t) < maximum_.GetUint64() + else if (!CheckDoubleMaximum(context, static_cast(i))) + return false; + } + + if (!multipleOf_.IsNull()) { + if (multipleOf_.IsUint64()) { + if (static_cast(i >= 0 ? i : -i) % multipleOf_.GetUint64() != 0) { + context.error_handler.NotMultipleOf(i, multipleOf_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMultipleOf); + } + } + else if (!CheckDoubleMultipleOf(context, static_cast(i))) + return false; + } + + return true; + } + + bool CheckUint(Context& context, uint64_t i) const { + if (!(type_ & ((1 << kIntegerSchemaType) | (1 << kNumberSchemaType)))) { + DisallowedType(context, GetIntegerString()); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorType); + } + + if (!minimum_.IsNull()) { + if (minimum_.IsUint64()) { + if (exclusiveMinimum_ ? i <= minimum_.GetUint64() : i < minimum_.GetUint64()) { + context.error_handler.BelowMinimum(i, minimum_, exclusiveMinimum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMinimum_ ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum); + } + } + else if (minimum_.IsInt64()) + /* do nothing */; // i >= 0 > minimum.Getint64() + else if (!CheckDoubleMinimum(context, static_cast(i))) + return false; + } + + if (!maximum_.IsNull()) { + if (maximum_.IsUint64()) { + if (exclusiveMaximum_ ? i >= maximum_.GetUint64() : i > maximum_.GetUint64()) { + context.error_handler.AboveMaximum(i, maximum_, exclusiveMaximum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMaximum_ ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum); + } + } + else if (maximum_.IsInt64()) { + context.error_handler.AboveMaximum(i, maximum_, exclusiveMaximum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMaximum_ ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum); // i >= 0 > maximum_ + } + else if (!CheckDoubleMaximum(context, static_cast(i))) + return false; + } + + if (!multipleOf_.IsNull()) { + if (multipleOf_.IsUint64()) { + if (i % multipleOf_.GetUint64() != 0) { + context.error_handler.NotMultipleOf(i, multipleOf_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMultipleOf); + } + } + else if (!CheckDoubleMultipleOf(context, static_cast(i))) + return false; + } + + return true; + } + + bool CheckDoubleMinimum(Context& context, double d) const { + if (exclusiveMinimum_ ? d <= minimum_.GetDouble() : d < minimum_.GetDouble()) { + context.error_handler.BelowMinimum(d, minimum_, exclusiveMinimum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMinimum_ ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum); + } + return true; + } + + bool CheckDoubleMaximum(Context& context, double d) const { + if (exclusiveMaximum_ ? d >= maximum_.GetDouble() : d > maximum_.GetDouble()) { + context.error_handler.AboveMaximum(d, maximum_, exclusiveMaximum_); + RAPIDJSON_INVALID_KEYWORD_RETURN(exclusiveMaximum_ ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum); + } + return true; + } + + bool CheckDoubleMultipleOf(Context& context, double d) const { + double a = std::abs(d), b = std::abs(multipleOf_.GetDouble()); + double q = a / b; + double qRounded = std::floor(q + 0.5); + double scaledEpsilon = (q + qRounded) * std::numeric_limits::epsilon(); + double difference = std::abs(qRounded - q); + bool isMultiple = difference <= scaledEpsilon || difference < (std::numeric_limits::min)(); + if (!isMultiple) { + context.error_handler.NotMultipleOf(d, multipleOf_); + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorMultipleOf); + } + return true; + } + + void DisallowedType(Context& context, const ValueType& actualType) const { + ErrorHandler& eh = context.error_handler; + eh.StartDisallowedType(); + + if (type_ & (1 << kNullSchemaType)) eh.AddExpectedType(GetNullString()); + if (type_ & (1 << kBooleanSchemaType)) eh.AddExpectedType(GetBooleanString()); + if (type_ & (1 << kObjectSchemaType)) eh.AddExpectedType(GetObjectString()); + if (type_ & (1 << kArraySchemaType)) eh.AddExpectedType(GetArrayString()); + if (type_ & (1 << kStringSchemaType)) eh.AddExpectedType(GetStringString()); + + if (type_ & (1 << kNumberSchemaType)) eh.AddExpectedType(GetNumberString()); + else if (type_ & (1 << kIntegerSchemaType)) eh.AddExpectedType(GetIntegerString()); + + eh.EndDisallowedType(actualType); + } + + struct Property { + Property() : schema(), dependenciesSchema(), dependenciesValidatorIndex(), dependencies(), required(false) {} + ~Property() { AllocatorType::Free(dependencies); } + SValue name; + const SchemaType* schema; + const SchemaType* dependenciesSchema; + SizeType dependenciesValidatorIndex; + bool* dependencies; + bool required; + }; + + struct PatternProperty { + PatternProperty() : schema(), pattern() {} + ~PatternProperty() { + if (pattern) { + pattern->~RegexType(); + AllocatorType::Free(pattern); + } + } + const SchemaType* schema; + RegexType* pattern; + }; + + AllocatorType* allocator_; + SValue uri_; + UriType id_; + Specification spec_; + PointerType pointer_; + const SchemaType* typeless_; + uint64_t* enum_; + SizeType enumCount_; + SchemaArray allOf_; + SchemaArray anyOf_; + SchemaArray oneOf_; + const SchemaType* not_; + unsigned type_; // bitmask of kSchemaType + SizeType validatorCount_; + SizeType notValidatorIndex_; + + Property* properties_; + const SchemaType* additionalPropertiesSchema_; + PatternProperty* patternProperties_; + SizeType patternPropertyCount_; + SizeType propertyCount_; + SizeType minProperties_; + SizeType maxProperties_; + bool additionalProperties_; + bool hasDependencies_; + bool hasRequired_; + bool hasSchemaDependencies_; + + const SchemaType* additionalItemsSchema_; + const SchemaType* itemsList_; + const SchemaType** itemsTuple_; + SizeType itemsTupleCount_; + SizeType minItems_; + SizeType maxItems_; + bool additionalItems_; + bool uniqueItems_; + + RegexType* pattern_; + SizeType minLength_; + SizeType maxLength_; + + SValue minimum_; + SValue maximum_; + SValue multipleOf_; + bool exclusiveMinimum_; + bool exclusiveMaximum_; + + SizeType defaultValueLength_; + + bool readOnly_; + bool writeOnly_; + bool nullable_; +}; + +template +struct TokenHelper { + RAPIDJSON_FORCEINLINE static void AppendIndexToken(Stack& documentStack, SizeType index) { + *documentStack.template Push() = '/'; + char buffer[21]; + size_t length = static_cast((sizeof(SizeType) == 4 ? u32toa(index, buffer) : u64toa(index, buffer)) - buffer); + for (size_t i = 0; i < length; i++) + *documentStack.template Push() = static_cast(buffer[i]); + } +}; + +// Partial specialized version for char to prevent buffer copying. +template +struct TokenHelper { + RAPIDJSON_FORCEINLINE static void AppendIndexToken(Stack& documentStack, SizeType index) { + RAPIDJSON_IF_CONSTEXPR (sizeof(SizeType) == 4) { + char *buffer = documentStack.template Push(1 + 10); // '/' + uint + *buffer++ = '/'; + const char* end = internal::u32toa(index, buffer); + documentStack.template Pop(static_cast(10 - (end - buffer))); + } + else { + char *buffer = documentStack.template Push(1 + 20); // '/' + uint64 + *buffer++ = '/'; + const char* end = internal::u64toa(index, buffer); + documentStack.template Pop(static_cast(20 - (end - buffer))); + } + } +}; + +} // namespace internal + +/////////////////////////////////////////////////////////////////////////////// +// IGenericRemoteSchemaDocumentProvider + +template +class IGenericRemoteSchemaDocumentProvider { +public: + typedef typename SchemaDocumentType::Ch Ch; + typedef typename SchemaDocumentType::ValueType ValueType; + typedef typename SchemaDocumentType::AllocatorType AllocatorType; + + virtual ~IGenericRemoteSchemaDocumentProvider() {} + virtual const SchemaDocumentType* GetRemoteDocument(const Ch* uri, SizeType length) = 0; + virtual const SchemaDocumentType* GetRemoteDocument(const GenericUri uri, Specification& spec) { + // Default implementation just calls through for compatibility + // Following line suppresses unused parameter warning + (void)spec; + // printf("GetRemoteDocument: %d %d\n", spec.draft, spec.oapi); + return GetRemoteDocument(uri.GetBaseString(), uri.GetBaseStringLength()); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// GenericSchemaDocument + +//! JSON schema document. +/*! + A JSON schema document is a compiled version of a JSON schema. + It is basically a tree of internal::Schema. + + \note This is an immutable class (i.e. its instance cannot be modified after construction). + \tparam ValueT Type of JSON value (e.g. \c Value ), which also determine the encoding. + \tparam Allocator Allocator type for allocating memory of this document. +*/ +template +class GenericSchemaDocument { +public: + typedef ValueT ValueType; + typedef IGenericRemoteSchemaDocumentProvider IRemoteSchemaDocumentProviderType; + typedef Allocator AllocatorType; + typedef typename ValueType::EncodingType EncodingType; + typedef typename EncodingType::Ch Ch; + typedef internal::Schema SchemaType; + typedef GenericPointer PointerType; + typedef GenericValue GValue; + typedef GenericUri UriType; + typedef GenericStringRef StringRefType; + friend class internal::Schema; + template + friend class GenericSchemaValidator; + + //! Constructor. + /*! + Compile a JSON document into schema document. + + \param document A JSON document as source. + \param uri The base URI of this schema document for purposes of violation reporting. + \param uriLength Length of \c name, in code points. + \param remoteProvider An optional remote schema document provider for resolving remote reference. Can be null. + \param allocator An optional allocator instance for allocating memory. Can be null. + \param pointer An optional JSON pointer to the start of the schema document + \param spec Optional schema draft or OpenAPI version. Used if no specification in document. Defaults to draft-04. + */ + explicit GenericSchemaDocument(const ValueType& document, const Ch* uri = 0, SizeType uriLength = 0, + IRemoteSchemaDocumentProviderType* remoteProvider = 0, Allocator* allocator = 0, + const PointerType& pointer = PointerType(), // PR #1393 + const Specification& spec = Specification(kDraft04)) : + remoteProvider_(remoteProvider), + allocator_(allocator), + ownAllocator_(), + root_(), + typeless_(), + schemaMap_(allocator, kInitialSchemaMapSize), + schemaRef_(allocator, kInitialSchemaRefSize), + spec_(spec), + error_(kObjectType), + currentError_() + { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaDocument::GenericSchemaDocument"); + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + + Ch noUri[1] = {0}; + uri_.SetString(uri ? uri : noUri, uriLength, *allocator_); + docId_ = UriType(uri_, allocator_); + + typeless_ = static_cast(allocator_->Malloc(sizeof(SchemaType))); + new (typeless_) SchemaType(this, PointerType(), ValueType(kObjectType).Move(), ValueType(kObjectType).Move(), allocator_, docId_); + + // Establish the schema draft or open api version. + // We only ever look for '$schema' or 'swagger' or 'openapi' at the root of the document. + SetSchemaSpecification(document); + + // Generate root schema, it will call CreateSchema() to create sub-schemas, + // And call HandleRefSchema() if there are $ref. + // PR #1393 use input pointer if supplied + root_ = typeless_; + if (pointer.GetTokenCount() == 0) { + CreateSchemaRecursive(&root_, pointer, document, document, docId_); + } + else if (const ValueType* v = pointer.Get(document)) { + CreateSchema(&root_, pointer, *v, document, docId_); + } + else { + GenericStringBuffer sb; + pointer.StringifyUriFragment(sb); + SchemaErrorValue(kSchemaErrorStartUnknown, PointerType(), sb.GetString(), static_cast(sb.GetSize() / sizeof(Ch))); + } + + RAPIDJSON_ASSERT(root_ != 0); + + schemaRef_.ShrinkToFit(); // Deallocate all memory for ref + } + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + //! Move constructor in C++11 + GenericSchemaDocument(GenericSchemaDocument&& rhs) RAPIDJSON_NOEXCEPT : + remoteProvider_(rhs.remoteProvider_), + allocator_(rhs.allocator_), + ownAllocator_(rhs.ownAllocator_), + root_(rhs.root_), + typeless_(rhs.typeless_), + schemaMap_(std::move(rhs.schemaMap_)), + schemaRef_(std::move(rhs.schemaRef_)), + uri_(std::move(rhs.uri_)), + docId_(std::move(rhs.docId_)), + spec_(rhs.spec_), + error_(std::move(rhs.error_)), + currentError_(std::move(rhs.currentError_)) + { + rhs.remoteProvider_ = 0; + rhs.allocator_ = 0; + rhs.ownAllocator_ = 0; + rhs.typeless_ = 0; + } +#endif + + //! Destructor + ~GenericSchemaDocument() { + while (!schemaMap_.Empty()) + schemaMap_.template Pop(1)->~SchemaEntry(); + + if (typeless_) { + typeless_->~SchemaType(); + Allocator::Free(typeless_); + } + + // these may contain some allocator data so clear before deleting ownAllocator_ + uri_.SetNull(); + error_.SetNull(); + currentError_.SetNull(); + + RAPIDJSON_DELETE(ownAllocator_); + } + + const GValue& GetURI() const { return uri_; } + + const Specification& GetSpecification() const { return spec_; } + bool IsSupportedSpecification() const { return spec_.IsSupported(); } + + //! Static method to get the specification of any schema document + // Returns kDraftNone if document is silent + static const Specification GetSpecification(const ValueType& document) { + SchemaDraft draft = GetSchemaDraft(document); + if (draft != kDraftNone) + return Specification(draft); + else { + OpenApiVersion oapi = GetOpenApiVersion(document); + if (oapi != kVersionNone) + return Specification(oapi); + } + return Specification(kDraftNone); + } + + //! Get the root schema. + const SchemaType& GetRoot() const { return *root_; } + + //! Gets the error object. + GValue& GetError() { return error_; } + const GValue& GetError() const { return error_; } + + static const StringRefType& GetSchemaErrorKeyword(SchemaErrorCode schemaErrorCode) { + switch (schemaErrorCode) { + case kSchemaErrorStartUnknown: return GetStartUnknownString(); + case kSchemaErrorRefPlainName: return GetRefPlainNameString(); + case kSchemaErrorRefInvalid: return GetRefInvalidString(); + case kSchemaErrorRefPointerInvalid: return GetRefPointerInvalidString(); + case kSchemaErrorRefUnknown: return GetRefUnknownString(); + case kSchemaErrorRefCyclical: return GetRefCyclicalString(); + case kSchemaErrorRefNoRemoteProvider: return GetRefNoRemoteProviderString(); + case kSchemaErrorRefNoRemoteSchema: return GetRefNoRemoteSchemaString(); + case kSchemaErrorRegexInvalid: return GetRegexInvalidString(); + case kSchemaErrorSpecUnknown: return GetSpecUnknownString(); + case kSchemaErrorSpecUnsupported: return GetSpecUnsupportedString(); + case kSchemaErrorSpecIllegal: return GetSpecIllegalString(); + case kSchemaErrorReadOnlyAndWriteOnly: return GetReadOnlyAndWriteOnlyString(); + default: return GetNullString(); + } + } + + //! Default error method + void SchemaError(const SchemaErrorCode code, const PointerType& location) { + currentError_ = GValue(kObjectType); + AddCurrentError(code, location); + } + + //! Method for error with single string value insert + void SchemaErrorValue(const SchemaErrorCode code, const PointerType& location, const Ch* value, SizeType length) { + currentError_ = GValue(kObjectType); + currentError_.AddMember(GetValueString(), GValue(value, length, *allocator_).Move(), *allocator_); + AddCurrentError(code, location); + } + + //! Method for error with invalid pointer + void SchemaErrorPointer(const SchemaErrorCode code, const PointerType& location, const Ch* value, SizeType length, const PointerType& pointer) { + currentError_ = GValue(kObjectType); + currentError_.AddMember(GetValueString(), GValue(value, length, *allocator_).Move(), *allocator_); + currentError_.AddMember(GetOffsetString(), static_cast(pointer.GetParseErrorOffset() / sizeof(Ch)), *allocator_); + AddCurrentError(code, location); + } + + private: + //! Prohibit copying + GenericSchemaDocument(const GenericSchemaDocument&); + //! Prohibit assignment + GenericSchemaDocument& operator=(const GenericSchemaDocument&); + + typedef const PointerType* SchemaRefPtr; // PR #1393 + + struct SchemaEntry { + SchemaEntry(const PointerType& p, SchemaType* s, bool o, Allocator* allocator) : pointer(p, allocator), schema(s), owned(o) {} + ~SchemaEntry() { + if (owned) { + schema->~SchemaType(); + Allocator::Free(schema); + } + } + PointerType pointer; + SchemaType* schema; + bool owned; + }; + + void AddErrorInstanceLocation(GValue& result, const PointerType& location) { + GenericStringBuffer sb; + location.StringifyUriFragment(sb); + GValue instanceRef(sb.GetString(), static_cast(sb.GetSize() / sizeof(Ch)), *allocator_); + result.AddMember(GetInstanceRefString(), instanceRef, *allocator_); + } + + void AddError(GValue& keyword, GValue& error) { + typename GValue::MemberIterator member = error_.FindMember(keyword); + if (member == error_.MemberEnd()) + error_.AddMember(keyword, error, *allocator_); + else { + if (member->value.IsObject()) { + GValue errors(kArrayType); + errors.PushBack(member->value, *allocator_); + member->value = errors; + } + member->value.PushBack(error, *allocator_); + } + } + + void AddCurrentError(const SchemaErrorCode code, const PointerType& location) { + RAPIDJSON_SCHEMA_PRINT(InvalidKeyword, GetSchemaErrorKeyword(code)); + currentError_.AddMember(GetErrorCodeString(), code, *allocator_); + AddErrorInstanceLocation(currentError_, location); + AddError(GValue(GetSchemaErrorKeyword(code)).Move(), currentError_); + } + +#define RAPIDJSON_STRING_(name, ...) \ + static const StringRefType& Get##name##String() {\ + static const Ch s[] = { __VA_ARGS__, '\0' };\ + static const StringRefType v(s, static_cast(sizeof(s) / sizeof(Ch) - 1)); \ + return v;\ + } + + RAPIDJSON_STRING_(InstanceRef, 'i', 'n', 's', 't', 'a', 'n', 'c', 'e', 'R', 'e', 'f') + RAPIDJSON_STRING_(ErrorCode, 'e', 'r', 'r', 'o', 'r', 'C', 'o', 'd', 'e') + RAPIDJSON_STRING_(Value, 'v', 'a', 'l', 'u', 'e') + RAPIDJSON_STRING_(Offset, 'o', 'f', 'f', 's', 'e', 't') + + RAPIDJSON_STRING_(Null, 'n', 'u', 'l', 'l') + RAPIDJSON_STRING_(SpecUnknown, 'S', 'p', 'e', 'c', 'U', 'n', 'k', 'n', 'o', 'w', 'n') + RAPIDJSON_STRING_(SpecUnsupported, 'S', 'p', 'e', 'c', 'U', 'n', 's', 'u', 'p', 'p', 'o', 'r', 't', 'e', 'd') + RAPIDJSON_STRING_(SpecIllegal, 'S', 'p', 'e', 'c', 'I', 'l', 'l', 'e', 'g', 'a', 'l') + RAPIDJSON_STRING_(StartUnknown, 'S', 't', 'a', 'r', 't', 'U', 'n', 'k', 'n', 'o', 'w', 'n') + RAPIDJSON_STRING_(RefPlainName, 'R', 'e', 'f', 'P', 'l', 'a', 'i', 'n', 'N', 'a', 'm', 'e') + RAPIDJSON_STRING_(RefInvalid, 'R', 'e', 'f', 'I', 'n', 'v', 'a', 'l', 'i', 'd') + RAPIDJSON_STRING_(RefPointerInvalid, 'R', 'e', 'f', 'P', 'o', 'i', 'n', 't', 'e', 'r', 'I', 'n', 'v', 'a', 'l', 'i', 'd') + RAPIDJSON_STRING_(RefUnknown, 'R', 'e', 'f', 'U', 'n', 'k', 'n', 'o', 'w', 'n') + RAPIDJSON_STRING_(RefCyclical, 'R', 'e', 'f', 'C', 'y', 'c', 'l', 'i', 'c', 'a', 'l') + RAPIDJSON_STRING_(RefNoRemoteProvider, 'R', 'e', 'f', 'N', 'o', 'R', 'e', 'm', 'o', 't', 'e', 'P', 'r', 'o', 'v', 'i', 'd', 'e', 'r') + RAPIDJSON_STRING_(RefNoRemoteSchema, 'R', 'e', 'f', 'N', 'o', 'R', 'e', 'm', 'o', 't', 'e', 'S', 'c', 'h', 'e', 'm', 'a') + RAPIDJSON_STRING_(ReadOnlyAndWriteOnly, 'R', 'e', 'a', 'd', 'O', 'n', 'l', 'y', 'A', 'n', 'd', 'W', 'r', 'i', 't', 'e', 'O', 'n', 'l', 'y') + RAPIDJSON_STRING_(RegexInvalid, 'R', 'e', 'g', 'e', 'x', 'I', 'n', 'v', 'a', 'l', 'i', 'd') + +#undef RAPIDJSON_STRING_ + + // Static method to get schema draft of any schema document + static SchemaDraft GetSchemaDraft(const ValueType& document) { + static const Ch kDraft03String[] = { 'h', 't', 't', 'p', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '-', '0', '3', '/', 's', 'c', 'h', 'e', 'm', 'a', '#', '\0' }; + static const Ch kDraft04String[] = { 'h', 't', 't', 'p', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '-', '0', '4', '/', 's', 'c', 'h', 'e', 'm', 'a', '#', '\0' }; + static const Ch kDraft05String[] = { 'h', 't', 't', 'p', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '-', '0', '5', '/', 's', 'c', 'h', 'e', 'm', 'a', '#', '\0' }; + static const Ch kDraft06String[] = { 'h', 't', 't', 'p', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '-', '0', '6', '/', 's', 'c', 'h', 'e', 'm', 'a', '#', '\0' }; + static const Ch kDraft07String[] = { 'h', 't', 't', 'p', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '-', '0', '7', '/', 's', 'c', 'h', 'e', 'm', 'a', '#', '\0' }; + static const Ch kDraft2019_09String[] = { 'h', 't', 't', 'p', 's', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '/', '2', '0', '1', '9', '-', '0', '9', '/', 's', 'c', 'h', 'e', 'm', 'a', '\0' }; + static const Ch kDraft2020_12String[] = { 'h', 't', 't', 'p', 's', ':', '/', '/', 'j', 's', 'o', 'n', '-', 's', 'c', 'h', 'e', 'm', 'a', '.', 'o', 'r', 'g', '/', 'd', 'r', 'a', 'f', 't', '/', '2', '0', '2', '0', '-', '1', '2', '/', 's', 'c', 'h', 'e', 'm', 'a', '\0' }; + + if (!document.IsObject()) { + return kDraftNone; + } + + // Get the schema draft from the $schema keyword at the supplied location + typename ValueType::ConstMemberIterator itr = document.FindMember(SchemaType::GetSchemaString()); + if (itr != document.MemberEnd()) { + if (!itr->value.IsString()) return kDraftUnknown; + const UriType draftUri(itr->value); + // Check base uri for match + if (draftUri.Match(UriType(kDraft04String), false)) return kDraft04; + if (draftUri.Match(UriType(kDraft05String), false)) return kDraft05; + if (draftUri.Match(UriType(kDraft06String), false)) return kDraft06; + if (draftUri.Match(UriType(kDraft07String), false)) return kDraft07; + if (draftUri.Match(UriType(kDraft03String), false)) return kDraft03; + if (draftUri.Match(UriType(kDraft2019_09String), false)) return kDraft2019_09; + if (draftUri.Match(UriType(kDraft2020_12String), false)) return kDraft2020_12; + return kDraftUnknown; + } + // $schema not found + return kDraftNone; + } + + + // Get open api version of any schema document + static OpenApiVersion GetOpenApiVersion(const ValueType& document) { + static const Ch kVersion20String[] = { '2', '.', '0', '\0' }; + static const Ch kVersion30String[] = { '3', '.', '0', '.', '\0' }; // ignore patch level + static const Ch kVersion31String[] = { '3', '.', '1', '.', '\0' }; // ignore patch level + static SizeType len = internal::StrLen(kVersion30String); + + if (!document.IsObject()) { + return kVersionNone; + } + + // Get the open api version from the swagger / openapi keyword at the supplied location + typename ValueType::ConstMemberIterator itr = document.FindMember(SchemaType::GetSwaggerString()); + if (itr == document.MemberEnd()) itr = document.FindMember(SchemaType::GetOpenApiString()); + if (itr != document.MemberEnd()) { + if (!itr->value.IsString()) return kVersionUnknown; + const ValueType kVersion20Value(kVersion20String); + if (kVersion20Value == itr->value) return kVersion20; // must match 2.0 exactly + const ValueType kVersion30Value(kVersion30String); + if (itr->value.GetStringLength() > len && kVersion30Value == ValueType(itr->value.GetString(), len)) return kVersion30; // must match 3.0.x + const ValueType kVersion31Value(kVersion31String); + if (itr->value.GetStringLength() > len && kVersion31Value == ValueType(itr->value.GetString(), len)) return kVersion31; // must match 3.1.x + return kVersionUnknown; + } + // swagger or openapi not found + return kVersionNone; + } + + // Get the draft of the schema or the open api version (which implies the draft). + // Report an error if schema draft or open api version not supported or not recognized, or both in document, and carry on. + void SetSchemaSpecification(const ValueType& document) { + // Look for '$schema', 'swagger' or 'openapi' keyword at document root + SchemaDraft docDraft = GetSchemaDraft(document); + OpenApiVersion docOapi = GetOpenApiVersion(document); + // Error if both in document + if (docDraft != kDraftNone && docOapi != kVersionNone) + SchemaError(kSchemaErrorSpecIllegal, PointerType()); + // Use document draft or open api version if present or use spec from constructor + if (docDraft != kDraftNone) + spec_ = Specification(docDraft); + else if (docOapi != kVersionNone) + spec_ = Specification(docOapi); + // Error if draft or version unknown + if (spec_.draft == kDraftUnknown || spec_.oapi == kVersionUnknown) + SchemaError(kSchemaErrorSpecUnknown, PointerType()); + else if (!spec_.IsSupported()) + SchemaError(kSchemaErrorSpecUnsupported, PointerType()); + } + + // Changed by PR #1393 + void CreateSchemaRecursive(const SchemaType** schema, const PointerType& pointer, const ValueType& v, const ValueType& document, const UriType& id) { + if (v.GetType() == kObjectType) { + UriType newid = UriType(CreateSchema(schema, pointer, v, document, id), allocator_); + + for (typename ValueType::ConstMemberIterator itr = v.MemberBegin(); itr != v.MemberEnd(); ++itr) + CreateSchemaRecursive(0, pointer.Append(itr->name, allocator_), itr->value, document, newid); + } + else if (v.GetType() == kArrayType) + for (SizeType i = 0; i < v.Size(); i++) + CreateSchemaRecursive(0, pointer.Append(i, allocator_), v[i], document, id); + } + + // Changed by PR #1393 + const UriType& CreateSchema(const SchemaType** schema, const PointerType& pointer, const ValueType& v, const ValueType& document, const UriType& id) { + RAPIDJSON_ASSERT(pointer.IsValid()); + GenericStringBuffer sb; + pointer.StringifyUriFragment(sb); + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaDocument::CreateSchema", sb.GetString(), id.GetString()); + if (v.IsObject()) { + if (const SchemaType* sc = GetSchema(pointer)) { + if (schema) + *schema = sc; + AddSchemaRefs(const_cast(sc)); + } + else if (!HandleRefSchema(pointer, schema, v, document, id)) { + // The new schema constructor adds itself and its $ref(s) to schemaMap_ + SchemaType* s = new (allocator_->Malloc(sizeof(SchemaType))) SchemaType(this, pointer, v, document, allocator_, id); + if (schema) + *schema = s; + return s->GetId(); + } + } + else { + if (schema) + *schema = typeless_; + AddSchemaRefs(typeless_); + } + return id; + } + + // Changed by PR #1393 + // TODO should this return a UriType& ? + bool HandleRefSchema(const PointerType& source, const SchemaType** schema, const ValueType& v, const ValueType& document, const UriType& id) { + typename ValueType::ConstMemberIterator itr = v.FindMember(SchemaType::GetRefString()); + if (itr == v.MemberEnd()) + return false; + + GenericStringBuffer sb; + source.StringifyUriFragment(sb); + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaDocument::HandleRefSchema", sb.GetString(), id.GetString()); + // Resolve the source pointer to the $ref'ed schema (finally) + new (schemaRef_.template Push()) SchemaRefPtr(&source); + + if (itr->value.IsString()) { + SizeType len = itr->value.GetStringLength(); + if (len == 0) + SchemaError(kSchemaErrorRefInvalid, source); + else { + // First resolve $ref against the in-scope id + UriType scopeId = UriType(id, allocator_); + UriType ref = UriType(itr->value, allocator_).Resolve(scopeId, allocator_); + RAPIDJSON_SCHEMA_PRINT(SchemaIds, id.GetString(), itr->value.GetString(), ref.GetString()); + // See if the resolved $ref minus the fragment matches a resolved id in this document + // Search from the root. Returns the subschema in the document and its absolute JSON pointer. + PointerType basePointer = PointerType(); + const ValueType *base = FindId(document, ref, basePointer, docId_, false); + if (!base) { + // Remote reference - call the remote document provider + if (!remoteProvider_) + SchemaError(kSchemaErrorRefNoRemoteProvider, source); + else { + if (const GenericSchemaDocument* remoteDocument = remoteProvider_->GetRemoteDocument(ref, spec_)) { + const Ch* s = ref.GetFragString(); + len = ref.GetFragStringLength(); + if (len <= 1 || s[1] == '/') { + // JSON pointer fragment, absolute in the remote schema + const PointerType pointer(s, len, allocator_); + if (!pointer.IsValid()) + SchemaErrorPointer(kSchemaErrorRefPointerInvalid, source, s, len, pointer); + else { + // Get the subschema + if (const SchemaType *sc = remoteDocument->GetSchema(pointer)) { + if (schema) + *schema = sc; + AddSchemaRefs(const_cast(sc)); + return true; + } else + SchemaErrorValue(kSchemaErrorRefUnknown, source, ref.GetString(), ref.GetStringLength()); + } + } else + // Plain name fragment, not allowed in remote schema + SchemaErrorValue(kSchemaErrorRefPlainName, source, s, len); + } else + SchemaErrorValue(kSchemaErrorRefNoRemoteSchema, source, ref.GetString(), ref.GetStringLength()); + } + } + else { // Local reference + const Ch* s = ref.GetFragString(); + len = ref.GetFragStringLength(); + if (len <= 1 || s[1] == '/') { + // JSON pointer fragment, relative to the resolved URI + const PointerType relPointer(s, len, allocator_); + if (!relPointer.IsValid()) + SchemaErrorPointer(kSchemaErrorRefPointerInvalid, source, s, len, relPointer); + else { + // Get the subschema + if (const ValueType *pv = relPointer.Get(*base)) { + // Now get the absolute JSON pointer by adding relative to base + PointerType pointer(basePointer, allocator_); + for (SizeType i = 0; i < relPointer.GetTokenCount(); i++) + pointer = pointer.Append(relPointer.GetTokens()[i], allocator_); + if (IsCyclicRef(pointer)) + SchemaErrorValue(kSchemaErrorRefCyclical, source, ref.GetString(), ref.GetStringLength()); + else { + // Call CreateSchema recursively, but first compute the in-scope id for the $ref target as we have jumped there + // TODO: cache pointer <-> id mapping + size_t unresolvedTokenIndex; + scopeId = pointer.GetUri(document, docId_, &unresolvedTokenIndex, allocator_); + CreateSchema(schema, pointer, *pv, document, scopeId); + return true; + } + } else + SchemaErrorValue(kSchemaErrorRefUnknown, source, ref.GetString(), ref.GetStringLength()); + } + } else { + // Plain name fragment, relative to the resolved URI + // Not supported in open api 2.0 and 3.0 + PointerType pointer(allocator_); + if (spec_.oapi == kVersion20 || spec_.oapi == kVersion30) + SchemaErrorValue(kSchemaErrorRefPlainName, source, s, len); + // See if the fragment matches an id in this document. + // Search from the base we just established. Returns the subschema in the document and its absolute JSON pointer. + else if (const ValueType *pv = FindId(*base, ref, pointer, UriType(ref.GetBaseString(), ref.GetBaseStringLength(), allocator_), true, basePointer)) { + if (IsCyclicRef(pointer)) + SchemaErrorValue(kSchemaErrorRefCyclical, source, ref.GetString(), ref.GetStringLength()); + else { + // Call CreateSchema recursively, but first compute the in-scope id for the $ref target as we have jumped there + // TODO: cache pointer <-> id mapping + size_t unresolvedTokenIndex; + scopeId = pointer.GetUri(document, docId_, &unresolvedTokenIndex, allocator_); + CreateSchema(schema, pointer, *pv, document, scopeId); + return true; + } + } else + SchemaErrorValue(kSchemaErrorRefUnknown, source, ref.GetString(), ref.GetStringLength()); + } + } + } + } + + // Invalid/Unknown $ref + if (schema) + *schema = typeless_; + AddSchemaRefs(typeless_); + return true; + } + + //! Find the first subschema with a resolved 'id' that matches the specified URI. + // If full specified use all URI else ignore fragment. + // If found, return a pointer to the subschema and its JSON pointer. + // TODO cache pointer <-> id mapping + ValueType* FindId(const ValueType& doc, const UriType& finduri, PointerType& resptr, const UriType& baseuri, bool full, const PointerType& here = PointerType()) const { + SizeType i = 0; + ValueType* resval = 0; + UriType tempuri = UriType(finduri, allocator_); + UriType localuri = UriType(baseuri, allocator_); + if (doc.GetType() == kObjectType) { + // Establish the base URI of this object + typename ValueType::ConstMemberIterator m = doc.FindMember(SchemaType::GetIdString()); + if (m != doc.MemberEnd() && m->value.GetType() == kStringType) { + localuri = UriType(m->value, allocator_).Resolve(baseuri, allocator_); + } + // See if it matches + if (localuri.Match(finduri, full)) { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaDocument::FindId (match)", full ? localuri.GetString() : localuri.GetBaseString()); + resval = const_cast(&doc); + resptr = here; + return resval; + } + // No match, continue looking + for (m = doc.MemberBegin(); m != doc.MemberEnd(); ++m) { + if (m->value.GetType() == kObjectType || m->value.GetType() == kArrayType) { + resval = FindId(m->value, finduri, resptr, localuri, full, here.Append(m->name.GetString(), m->name.GetStringLength(), allocator_)); + } + if (resval) break; + } + } else if (doc.GetType() == kArrayType) { + // Continue looking + for (typename ValueType::ConstValueIterator v = doc.Begin(); v != doc.End(); ++v) { + if (v->GetType() == kObjectType || v->GetType() == kArrayType) { + resval = FindId(*v, finduri, resptr, localuri, full, here.Append(i, allocator_)); + } + if (resval) break; + i++; + } + } + return resval; + } + + // Added by PR #1393 + void AddSchemaRefs(SchemaType* schema) { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaDocument::AddSchemaRefs"); + while (!schemaRef_.Empty()) { + SchemaRefPtr *ref = schemaRef_.template Pop(1); + SchemaEntry *entry = schemaMap_.template Push(); + new (entry) SchemaEntry(**ref, schema, false, allocator_); + } + } + + // Added by PR #1393 + bool IsCyclicRef(const PointerType& pointer) const { + for (const SchemaRefPtr* ref = schemaRef_.template Bottom(); ref != schemaRef_.template End(); ++ref) + if (pointer == **ref) + return true; + return false; + } + + const SchemaType* GetSchema(const PointerType& pointer) const { + for (const SchemaEntry* target = schemaMap_.template Bottom(); target != schemaMap_.template End(); ++target) + if (pointer == target->pointer) + return target->schema; + return 0; + } + + PointerType GetPointer(const SchemaType* schema) const { + for (const SchemaEntry* target = schemaMap_.template Bottom(); target != schemaMap_.template End(); ++target) + if (schema == target->schema) + return target->pointer; + return PointerType(); + } + + const SchemaType* GetTypeless() const { return typeless_; } + + static const size_t kInitialSchemaMapSize = 64; + static const size_t kInitialSchemaRefSize = 64; + + IRemoteSchemaDocumentProviderType* remoteProvider_; + Allocator *allocator_; + Allocator *ownAllocator_; + const SchemaType* root_; //!< Root schema. + SchemaType* typeless_; + internal::Stack schemaMap_; // Stores created Pointer -> Schemas + internal::Stack schemaRef_; // Stores Pointer(s) from $ref(s) until resolved + GValue uri_; // Schema document URI + UriType docId_; + Specification spec_; + GValue error_; + GValue currentError_; +}; + +//! GenericSchemaDocument using Value type. +typedef GenericSchemaDocument SchemaDocument; +//! IGenericRemoteSchemaDocumentProvider using SchemaDocument. +typedef IGenericRemoteSchemaDocumentProvider IRemoteSchemaDocumentProvider; + +/////////////////////////////////////////////////////////////////////////////// +// GenericSchemaValidator + +//! JSON Schema Validator. +/*! + A SAX style JSON schema validator. + It uses a \c GenericSchemaDocument to validate SAX events. + It delegates the incoming SAX events to an output handler. + The default output handler does nothing. + It can be reused multiple times by calling \c Reset(). + + \tparam SchemaDocumentType Type of schema document. + \tparam OutputHandler Type of output handler. Default handler does nothing. + \tparam StateAllocator Allocator for storing the internal validation states. +*/ +template < + typename SchemaDocumentType, + typename OutputHandler = BaseReaderHandler, + typename StateAllocator = CrtAllocator> +class GenericSchemaValidator : + public internal::ISchemaStateFactory, + public internal::ISchemaValidator, + public internal::IValidationErrorHandler { +public: + typedef typename SchemaDocumentType::SchemaType SchemaType; + typedef typename SchemaDocumentType::PointerType PointerType; + typedef typename SchemaType::EncodingType EncodingType; + typedef typename SchemaType::SValue SValue; + typedef typename EncodingType::Ch Ch; + typedef GenericStringRef StringRefType; + typedef GenericValue ValueType; + + //! Constructor without output handler. + /*! + \param schemaDocument The schema document to conform to. + \param allocator Optional allocator for storing internal validation states. + \param schemaStackCapacity Optional initial capacity of schema path stack. + \param documentStackCapacity Optional initial capacity of document path stack. + */ + GenericSchemaValidator( + const SchemaDocumentType& schemaDocument, + StateAllocator* allocator = 0, + size_t schemaStackCapacity = kDefaultSchemaStackCapacity, + size_t documentStackCapacity = kDefaultDocumentStackCapacity) + : + schemaDocument_(&schemaDocument), + root_(schemaDocument.GetRoot()), + stateAllocator_(allocator), + ownStateAllocator_(0), + schemaStack_(allocator, schemaStackCapacity), + documentStack_(allocator, documentStackCapacity), + outputHandler_(0), + error_(kObjectType), + currentError_(), + missingDependents_(), + valid_(true), + flags_(kValidateDefaultFlags), + depth_(0) + { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::GenericSchemaValidator"); + } + + //! Constructor with output handler. + /*! + \param schemaDocument The schema document to conform to. + \param allocator Optional allocator for storing internal validation states. + \param schemaStackCapacity Optional initial capacity of schema path stack. + \param documentStackCapacity Optional initial capacity of document path stack. + */ + GenericSchemaValidator( + const SchemaDocumentType& schemaDocument, + OutputHandler& outputHandler, + StateAllocator* allocator = 0, + size_t schemaStackCapacity = kDefaultSchemaStackCapacity, + size_t documentStackCapacity = kDefaultDocumentStackCapacity) + : + schemaDocument_(&schemaDocument), + root_(schemaDocument.GetRoot()), + stateAllocator_(allocator), + ownStateAllocator_(0), + schemaStack_(allocator, schemaStackCapacity), + documentStack_(allocator, documentStackCapacity), + outputHandler_(&outputHandler), + error_(kObjectType), + currentError_(), + missingDependents_(), + valid_(true), + flags_(kValidateDefaultFlags), + depth_(0) + { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::GenericSchemaValidator (output handler)"); + } + + //! Destructor. + ~GenericSchemaValidator() { + Reset(); + RAPIDJSON_DELETE(ownStateAllocator_); + } + + //! Reset the internal states. + void Reset() { + while (!schemaStack_.Empty()) + PopSchema(); + documentStack_.Clear(); + ResetError(); + } + + //! Reset the error state. + void ResetError() { + error_.SetObject(); + currentError_.SetNull(); + missingDependents_.SetNull(); + valid_ = true; + } + + //! Implementation of ISchemaValidator + void SetValidateFlags(unsigned flags) { + flags_ = flags; + } + virtual unsigned GetValidateFlags() const { + return flags_; + } + + virtual bool IsValid() const { + if (!valid_) return false; + if (GetContinueOnErrors() && !error_.ObjectEmpty()) return false; + return true; + } + //! End of Implementation of ISchemaValidator + + //! Gets the error object. + ValueType& GetError() { return error_; } + const ValueType& GetError() const { return error_; } + + //! Gets the JSON pointer pointed to the invalid schema. + // If reporting all errors, the stack will be empty. + PointerType GetInvalidSchemaPointer() const { + return schemaStack_.Empty() ? PointerType() : CurrentSchema().GetPointer(); + } + + //! Gets the keyword of invalid schema. + // If reporting all errors, the stack will be empty, so return "errors". + const Ch* GetInvalidSchemaKeyword() const { + if (!schemaStack_.Empty()) return CurrentContext().invalidKeyword; + if (GetContinueOnErrors() && !error_.ObjectEmpty()) return static_cast(GetErrorsString()); + return 0; + } + + //! Gets the error code of invalid schema. + // If reporting all errors, the stack will be empty, so return kValidateErrors. + ValidateErrorCode GetInvalidSchemaCode() const { + if (!schemaStack_.Empty()) return CurrentContext().invalidCode; + if (GetContinueOnErrors() && !error_.ObjectEmpty()) return kValidateErrors; + return kValidateErrorNone; + } + + //! Gets the JSON pointer pointed to the invalid value. + // If reporting all errors, the stack will be empty. + PointerType GetInvalidDocumentPointer() const { + if (documentStack_.Empty()) { + return PointerType(); + } + else { + return PointerType(documentStack_.template Bottom(), documentStack_.GetSize() / sizeof(Ch)); + } + } + + void NotMultipleOf(int64_t actual, const SValue& expected) { + AddNumberError(kValidateErrorMultipleOf, ValueType(actual).Move(), expected); + } + void NotMultipleOf(uint64_t actual, const SValue& expected) { + AddNumberError(kValidateErrorMultipleOf, ValueType(actual).Move(), expected); + } + void NotMultipleOf(double actual, const SValue& expected) { + AddNumberError(kValidateErrorMultipleOf, ValueType(actual).Move(), expected); + } + void AboveMaximum(int64_t actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMaximumString : 0); + } + void AboveMaximum(uint64_t actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMaximumString : 0); + } + void AboveMaximum(double actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMaximum : kValidateErrorMaximum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMaximumString : 0); + } + void BelowMinimum(int64_t actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMinimumString : 0); + } + void BelowMinimum(uint64_t actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMinimumString : 0); + } + void BelowMinimum(double actual, const SValue& expected, bool exclusive) { + AddNumberError(exclusive ? kValidateErrorExclusiveMinimum : kValidateErrorMinimum, ValueType(actual).Move(), expected, + exclusive ? &SchemaType::GetExclusiveMinimumString : 0); + } + + void TooLong(const Ch* str, SizeType length, SizeType expected) { + AddNumberError(kValidateErrorMaxLength, + ValueType(str, length, GetStateAllocator()).Move(), SValue(expected).Move()); + } + void TooShort(const Ch* str, SizeType length, SizeType expected) { + AddNumberError(kValidateErrorMinLength, + ValueType(str, length, GetStateAllocator()).Move(), SValue(expected).Move()); + } + void DoesNotMatch(const Ch* str, SizeType length) { + currentError_.SetObject(); + currentError_.AddMember(GetActualString(), ValueType(str, length, GetStateAllocator()).Move(), GetStateAllocator()); + AddCurrentError(kValidateErrorPattern); + } + + void DisallowedItem(SizeType index) { + currentError_.SetObject(); + currentError_.AddMember(GetDisallowedString(), ValueType(index).Move(), GetStateAllocator()); + AddCurrentError(kValidateErrorAdditionalItems, true); + } + void TooFewItems(SizeType actualCount, SizeType expectedCount) { + AddNumberError(kValidateErrorMinItems, + ValueType(actualCount).Move(), SValue(expectedCount).Move()); + } + void TooManyItems(SizeType actualCount, SizeType expectedCount) { + AddNumberError(kValidateErrorMaxItems, + ValueType(actualCount).Move(), SValue(expectedCount).Move()); + } + void DuplicateItems(SizeType index1, SizeType index2) { + ValueType duplicates(kArrayType); + duplicates.PushBack(index1, GetStateAllocator()); + duplicates.PushBack(index2, GetStateAllocator()); + currentError_.SetObject(); + currentError_.AddMember(GetDuplicatesString(), duplicates, GetStateAllocator()); + AddCurrentError(kValidateErrorUniqueItems, true); + } + + void TooManyProperties(SizeType actualCount, SizeType expectedCount) { + AddNumberError(kValidateErrorMaxProperties, + ValueType(actualCount).Move(), SValue(expectedCount).Move()); + } + void TooFewProperties(SizeType actualCount, SizeType expectedCount) { + AddNumberError(kValidateErrorMinProperties, + ValueType(actualCount).Move(), SValue(expectedCount).Move()); + } + void StartMissingProperties() { + currentError_.SetArray(); + } + void AddMissingProperty(const SValue& name) { + currentError_.PushBack(ValueType(name, GetStateAllocator()).Move(), GetStateAllocator()); + } + bool EndMissingProperties() { + if (currentError_.Empty()) + return false; + ValueType error(kObjectType); + error.AddMember(GetMissingString(), currentError_, GetStateAllocator()); + currentError_ = error; + AddCurrentError(kValidateErrorRequired); + return true; + } + void PropertyViolations(ISchemaValidator** subvalidators, SizeType count) { + for (SizeType i = 0; i < count; ++i) + MergeError(static_cast(subvalidators[i])->GetError()); + } + void DisallowedProperty(const Ch* name, SizeType length) { + currentError_.SetObject(); + currentError_.AddMember(GetDisallowedString(), ValueType(name, length, GetStateAllocator()).Move(), GetStateAllocator()); + AddCurrentError(kValidateErrorAdditionalProperties, true); + } + + void StartDependencyErrors() { + currentError_.SetObject(); + } + void StartMissingDependentProperties() { + missingDependents_.SetArray(); + } + void AddMissingDependentProperty(const SValue& targetName) { + missingDependents_.PushBack(ValueType(targetName, GetStateAllocator()).Move(), GetStateAllocator()); + } + void EndMissingDependentProperties(const SValue& sourceName) { + if (!missingDependents_.Empty()) { + // Create equivalent 'required' error + ValueType error(kObjectType); + ValidateErrorCode code = kValidateErrorRequired; + error.AddMember(GetMissingString(), missingDependents_.Move(), GetStateAllocator()); + AddErrorCode(error, code); + AddErrorInstanceLocation(error, false); + // When appending to a pointer ensure its allocator is used + PointerType schemaRef = GetInvalidSchemaPointer().Append(SchemaType::GetValidateErrorKeyword(kValidateErrorDependencies), &GetInvalidSchemaPointer().GetAllocator()); + AddErrorSchemaLocation(error, schemaRef.Append(sourceName.GetString(), sourceName.GetStringLength(), &GetInvalidSchemaPointer().GetAllocator())); + ValueType wrapper(kObjectType); + wrapper.AddMember(ValueType(SchemaType::GetValidateErrorKeyword(code), GetStateAllocator()).Move(), error, GetStateAllocator()); + currentError_.AddMember(ValueType(sourceName, GetStateAllocator()).Move(), wrapper, GetStateAllocator()); + } + } + void AddDependencySchemaError(const SValue& sourceName, ISchemaValidator* subvalidator) { + currentError_.AddMember(ValueType(sourceName, GetStateAllocator()).Move(), + static_cast(subvalidator)->GetError(), GetStateAllocator()); + } + bool EndDependencyErrors() { + if (currentError_.ObjectEmpty()) + return false; + ValueType error(kObjectType); + error.AddMember(GetErrorsString(), currentError_, GetStateAllocator()); + currentError_ = error; + AddCurrentError(kValidateErrorDependencies); + return true; + } + + void DisallowedValue(const ValidateErrorCode code = kValidateErrorEnum) { + currentError_.SetObject(); + AddCurrentError(code); + } + void StartDisallowedType() { + currentError_.SetArray(); + } + void AddExpectedType(const typename SchemaType::ValueType& expectedType) { + currentError_.PushBack(ValueType(expectedType, GetStateAllocator()).Move(), GetStateAllocator()); + } + void EndDisallowedType(const typename SchemaType::ValueType& actualType) { + ValueType error(kObjectType); + error.AddMember(GetExpectedString(), currentError_, GetStateAllocator()); + error.AddMember(GetActualString(), ValueType(actualType, GetStateAllocator()).Move(), GetStateAllocator()); + currentError_ = error; + AddCurrentError(kValidateErrorType); + } + void NotAllOf(ISchemaValidator** subvalidators, SizeType count) { + // Treat allOf like oneOf and anyOf to match https://rapidjson.org/md_doc_schema.html#allOf-anyOf-oneOf + AddErrorArray(kValidateErrorAllOf, subvalidators, count); + //for (SizeType i = 0; i < count; ++i) { + // MergeError(static_cast(subvalidators[i])->GetError()); + //} + } + void NoneOf(ISchemaValidator** subvalidators, SizeType count) { + AddErrorArray(kValidateErrorAnyOf, subvalidators, count); + } + void NotOneOf(ISchemaValidator** subvalidators, SizeType count) { + AddErrorArray(kValidateErrorOneOf, subvalidators, count); + } + void MultipleOneOf(SizeType index1, SizeType index2) { + ValueType matches(kArrayType); + matches.PushBack(index1, GetStateAllocator()); + matches.PushBack(index2, GetStateAllocator()); + currentError_.SetObject(); + currentError_.AddMember(GetMatchesString(), matches, GetStateAllocator()); + AddCurrentError(kValidateErrorOneOfMatch); + } + void Disallowed() { + currentError_.SetObject(); + AddCurrentError(kValidateErrorNot); + } + void DisallowedWhenWriting() { + currentError_.SetObject(); + AddCurrentError(kValidateErrorReadOnly); + } + void DisallowedWhenReading() { + currentError_.SetObject(); + AddCurrentError(kValidateErrorWriteOnly); + } + +#define RAPIDJSON_STRING_(name, ...) \ + static const StringRefType& Get##name##String() {\ + static const Ch s[] = { __VA_ARGS__, '\0' };\ + static const StringRefType v(s, static_cast(sizeof(s) / sizeof(Ch) - 1)); \ + return v;\ + } + + RAPIDJSON_STRING_(InstanceRef, 'i', 'n', 's', 't', 'a', 'n', 'c', 'e', 'R', 'e', 'f') + RAPIDJSON_STRING_(SchemaRef, 's', 'c', 'h', 'e', 'm', 'a', 'R', 'e', 'f') + RAPIDJSON_STRING_(Expected, 'e', 'x', 'p', 'e', 'c', 't', 'e', 'd') + RAPIDJSON_STRING_(Actual, 'a', 'c', 't', 'u', 'a', 'l') + RAPIDJSON_STRING_(Disallowed, 'd', 'i', 's', 'a', 'l', 'l', 'o', 'w', 'e', 'd') + RAPIDJSON_STRING_(Missing, 'm', 'i', 's', 's', 'i', 'n', 'g') + RAPIDJSON_STRING_(Errors, 'e', 'r', 'r', 'o', 'r', 's') + RAPIDJSON_STRING_(ErrorCode, 'e', 'r', 'r', 'o', 'r', 'C', 'o', 'd', 'e') + RAPIDJSON_STRING_(ErrorMessage, 'e', 'r', 'r', 'o', 'r', 'M', 'e', 's', 's', 'a', 'g', 'e') + RAPIDJSON_STRING_(Duplicates, 'd', 'u', 'p', 'l', 'i', 'c', 'a', 't', 'e', 's') + RAPIDJSON_STRING_(Matches, 'm', 'a', 't', 'c', 'h', 'e', 's') + +#undef RAPIDJSON_STRING_ + +#define RAPIDJSON_SCHEMA_HANDLE_BEGIN_(method, arg1)\ + if (!valid_) return false; \ + if ((!BeginValue() && !GetContinueOnErrors()) || (!CurrentSchema().method arg1 && !GetContinueOnErrors())) {\ + *documentStack_.template Push() = '\0';\ + documentStack_.template Pop(1);\ + RAPIDJSON_SCHEMA_PRINT(InvalidDocument, documentStack_.template Bottom());\ + valid_ = false;\ + return valid_;\ + } + +#define RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(method, arg2)\ + for (Context* context = schemaStack_.template Bottom(); context != schemaStack_.template End(); context++) {\ + if (context->hasher)\ + static_cast(context->hasher)->method arg2;\ + if (context->validators)\ + for (SizeType i_ = 0; i_ < context->validatorCount; i_++)\ + static_cast(context->validators[i_])->method arg2;\ + if (context->patternPropertiesValidators)\ + for (SizeType i_ = 0; i_ < context->patternPropertiesValidatorCount; i_++)\ + static_cast(context->patternPropertiesValidators[i_])->method arg2;\ + } + +#define RAPIDJSON_SCHEMA_HANDLE_END_(method, arg2)\ + valid_ = (EndValue() || GetContinueOnErrors()) && (!outputHandler_ || outputHandler_->method arg2);\ + return valid_; + +#define RAPIDJSON_SCHEMA_HANDLE_VALUE_(method, arg1, arg2) \ + RAPIDJSON_SCHEMA_HANDLE_BEGIN_ (method, arg1);\ + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(method, arg2);\ + RAPIDJSON_SCHEMA_HANDLE_END_ (method, arg2) + + bool Null() { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Null, (CurrentContext()), ( )); } + bool Bool(bool b) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Bool, (CurrentContext(), b), (b)); } + bool Int(int i) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Int, (CurrentContext(), i), (i)); } + bool Uint(unsigned u) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Uint, (CurrentContext(), u), (u)); } + bool Int64(int64_t i) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Int64, (CurrentContext(), i), (i)); } + bool Uint64(uint64_t u) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Uint64, (CurrentContext(), u), (u)); } + bool Double(double d) { RAPIDJSON_SCHEMA_HANDLE_VALUE_(Double, (CurrentContext(), d), (d)); } + bool RawNumber(const Ch* str, SizeType length, bool copy) + { RAPIDJSON_SCHEMA_HANDLE_VALUE_(String, (CurrentContext(), str, length, copy), (str, length, copy)); } + bool String(const Ch* str, SizeType length, bool copy) + { RAPIDJSON_SCHEMA_HANDLE_VALUE_(String, (CurrentContext(), str, length, copy), (str, length, copy)); } + + bool StartObject() { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::StartObject"); + RAPIDJSON_SCHEMA_HANDLE_BEGIN_(StartObject, (CurrentContext())); + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(StartObject, ()); + valid_ = !outputHandler_ || outputHandler_->StartObject(); + return valid_; + } + + bool Key(const Ch* str, SizeType len, bool copy) { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::Key", str); + if (!valid_) return false; + AppendToken(str, len); + if (!CurrentSchema().Key(CurrentContext(), str, len, copy) && !GetContinueOnErrors()) { + valid_ = false; + return valid_; + } + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(Key, (str, len, copy)); + valid_ = !outputHandler_ || outputHandler_->Key(str, len, copy); + return valid_; + } + + bool EndObject(SizeType memberCount) { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::EndObject"); + if (!valid_) return false; + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(EndObject, (memberCount)); + if (!CurrentSchema().EndObject(CurrentContext(), memberCount) && !GetContinueOnErrors()) { + valid_ = false; + return valid_; + } + RAPIDJSON_SCHEMA_HANDLE_END_(EndObject, (memberCount)); + } + + bool StartArray() { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::StartArray"); + RAPIDJSON_SCHEMA_HANDLE_BEGIN_(StartArray, (CurrentContext())); + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(StartArray, ()); + valid_ = !outputHandler_ || outputHandler_->StartArray(); + return valid_; + } + + bool EndArray(SizeType elementCount) { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::EndArray"); + if (!valid_) return false; + RAPIDJSON_SCHEMA_HANDLE_PARALLEL_(EndArray, (elementCount)); + if (!CurrentSchema().EndArray(CurrentContext(), elementCount) && !GetContinueOnErrors()) { + valid_ = false; + return valid_; + } + RAPIDJSON_SCHEMA_HANDLE_END_(EndArray, (elementCount)); + } + +#undef RAPIDJSON_SCHEMA_HANDLE_BEGIN_ +#undef RAPIDJSON_SCHEMA_HANDLE_PARALLEL_ +#undef RAPIDJSON_SCHEMA_HANDLE_VALUE_ + + // Implementation of ISchemaStateFactory + virtual ISchemaValidator* CreateSchemaValidator(const SchemaType& root, const bool inheritContinueOnErrors) { + *documentStack_.template Push() = '\0'; + documentStack_.template Pop(1); + ISchemaValidator* sv = new (GetStateAllocator().Malloc(sizeof(GenericSchemaValidator))) GenericSchemaValidator(*schemaDocument_, root, documentStack_.template Bottom(), documentStack_.GetSize(), + depth_ + 1, + &GetStateAllocator()); + sv->SetValidateFlags(inheritContinueOnErrors ? GetValidateFlags() : GetValidateFlags() & ~static_cast(kValidateContinueOnErrorFlag)); + return sv; + } + + virtual void DestroySchemaValidator(ISchemaValidator* validator) { + GenericSchemaValidator* v = static_cast(validator); + v->~GenericSchemaValidator(); + StateAllocator::Free(v); + } + + virtual void* CreateHasher() { + return new (GetStateAllocator().Malloc(sizeof(HasherType))) HasherType(&GetStateAllocator()); + } + + virtual uint64_t GetHashCode(void* hasher) { + return static_cast(hasher)->GetHashCode(); + } + + virtual void DestroryHasher(void* hasher) { + HasherType* h = static_cast(hasher); + h->~HasherType(); + StateAllocator::Free(h); + } + + virtual void* MallocState(size_t size) { + return GetStateAllocator().Malloc(size); + } + + virtual void FreeState(void* p) { + StateAllocator::Free(p); + } + // End of implementation of ISchemaStateFactory + +private: + typedef typename SchemaType::Context Context; + typedef GenericValue, StateAllocator> HashCodeArray; + typedef internal::Hasher HasherType; + + GenericSchemaValidator( + const SchemaDocumentType& schemaDocument, + const SchemaType& root, + const char* basePath, size_t basePathSize, + unsigned depth, + StateAllocator* allocator = 0, + size_t schemaStackCapacity = kDefaultSchemaStackCapacity, + size_t documentStackCapacity = kDefaultDocumentStackCapacity) + : + schemaDocument_(&schemaDocument), + root_(root), + stateAllocator_(allocator), + ownStateAllocator_(0), + schemaStack_(allocator, schemaStackCapacity), + documentStack_(allocator, documentStackCapacity), + outputHandler_(0), + error_(kObjectType), + currentError_(), + missingDependents_(), + valid_(true), + flags_(kValidateDefaultFlags), + depth_(depth) + { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::GenericSchemaValidator (internal)", basePath && basePathSize ? basePath : ""); + if (basePath && basePathSize) + memcpy(documentStack_.template Push(basePathSize), basePath, basePathSize); + } + + StateAllocator& GetStateAllocator() { + if (!stateAllocator_) + stateAllocator_ = ownStateAllocator_ = RAPIDJSON_NEW(StateAllocator)(); + return *stateAllocator_; + } + + bool GetContinueOnErrors() const { + return flags_ & kValidateContinueOnErrorFlag; + } + + bool BeginValue() { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::BeginValue"); + if (schemaStack_.Empty()) + PushSchema(root_); + else { + if (CurrentContext().inArray) + internal::TokenHelper, Ch>::AppendIndexToken(documentStack_, CurrentContext().arrayElementIndex); + + if (!CurrentSchema().BeginValue(CurrentContext()) && !GetContinueOnErrors()) + return false; + + SizeType count = CurrentContext().patternPropertiesSchemaCount; + const SchemaType** sa = CurrentContext().patternPropertiesSchemas; + typename Context::PatternValidatorType patternValidatorType = CurrentContext().valuePatternValidatorType; + bool valueUniqueness = CurrentContext().valueUniqueness; + RAPIDJSON_ASSERT(CurrentContext().valueSchema); + PushSchema(*CurrentContext().valueSchema); + + if (count > 0) { + CurrentContext().objectPatternValidatorType = patternValidatorType; + ISchemaValidator**& va = CurrentContext().patternPropertiesValidators; + SizeType& validatorCount = CurrentContext().patternPropertiesValidatorCount; + va = static_cast(MallocState(sizeof(ISchemaValidator*) * count)); + std::memset(va, 0, sizeof(ISchemaValidator*) * count); + for (SizeType i = 0; i < count; i++) + va[validatorCount++] = CreateSchemaValidator(*sa[i], true); // Inherit continueOnError + } + + CurrentContext().arrayUniqueness = valueUniqueness; + } + return true; + } + + bool EndValue() { + RAPIDJSON_SCHEMA_PRINT(Method, "GenericSchemaValidator::EndValue"); + if (!CurrentSchema().EndValue(CurrentContext()) && !GetContinueOnErrors()) + return false; + + GenericStringBuffer sb; + schemaDocument_->GetPointer(&CurrentSchema()).StringifyUriFragment(sb); + *documentStack_.template Push() = '\0'; + documentStack_.template Pop(1); + RAPIDJSON_SCHEMA_PRINT(ValidatorPointers, sb.GetString(), documentStack_.template Bottom(), depth_); + void* hasher = CurrentContext().hasher; + uint64_t h = hasher && CurrentContext().arrayUniqueness ? static_cast(hasher)->GetHashCode() : 0; + + PopSchema(); + + if (!schemaStack_.Empty()) { + Context& context = CurrentContext(); + // Only check uniqueness if there is a hasher + if (hasher && context.valueUniqueness) { + HashCodeArray* a = static_cast(context.arrayElementHashCodes); + if (!a) + CurrentContext().arrayElementHashCodes = a = new (GetStateAllocator().Malloc(sizeof(HashCodeArray))) HashCodeArray(kArrayType); + for (typename HashCodeArray::ConstValueIterator itr = a->Begin(); itr != a->End(); ++itr) + if (itr->GetUint64() == h) { + DuplicateItems(static_cast(itr - a->Begin()), a->Size()); + // Cleanup before returning if continuing + if (GetContinueOnErrors()) { + a->PushBack(h, GetStateAllocator()); + while (!documentStack_.Empty() && *documentStack_.template Pop(1) != '/'); + } + RAPIDJSON_INVALID_KEYWORD_RETURN(kValidateErrorUniqueItems); + } + a->PushBack(h, GetStateAllocator()); + } + } + + // Remove the last token of document pointer + while (!documentStack_.Empty() && *documentStack_.template Pop(1) != '/') + ; + + return true; + } + + void AppendToken(const Ch* str, SizeType len) { + documentStack_.template Reserve(1 + len * 2); // worst case all characters are escaped as two characters + *documentStack_.template PushUnsafe() = '/'; + for (SizeType i = 0; i < len; i++) { + if (str[i] == '~') { + *documentStack_.template PushUnsafe() = '~'; + *documentStack_.template PushUnsafe() = '0'; + } + else if (str[i] == '/') { + *documentStack_.template PushUnsafe() = '~'; + *documentStack_.template PushUnsafe() = '1'; + } + else + *documentStack_.template PushUnsafe() = str[i]; + } + } + + RAPIDJSON_FORCEINLINE void PushSchema(const SchemaType& schema) { new (schemaStack_.template Push()) Context(*this, *this, &schema, flags_); } + + RAPIDJSON_FORCEINLINE void PopSchema() { + Context* c = schemaStack_.template Pop(1); + if (HashCodeArray* a = static_cast(c->arrayElementHashCodes)) { + a->~HashCodeArray(); + StateAllocator::Free(a); + } + c->~Context(); + } + + void AddErrorInstanceLocation(ValueType& result, bool parent) { + GenericStringBuffer sb; + PointerType instancePointer = GetInvalidDocumentPointer(); + ((parent && instancePointer.GetTokenCount() > 0) + ? PointerType(instancePointer.GetTokens(), instancePointer.GetTokenCount() - 1) + : instancePointer).StringifyUriFragment(sb); + ValueType instanceRef(sb.GetString(), static_cast(sb.GetSize() / sizeof(Ch)), + GetStateAllocator()); + result.AddMember(GetInstanceRefString(), instanceRef, GetStateAllocator()); + } + + void AddErrorSchemaLocation(ValueType& result, PointerType schema = PointerType()) { + GenericStringBuffer sb; + SizeType len = CurrentSchema().GetURI().GetStringLength(); + if (len) memcpy(sb.Push(len), CurrentSchema().GetURI().GetString(), len * sizeof(Ch)); + if (schema.GetTokenCount()) schema.StringifyUriFragment(sb); + else GetInvalidSchemaPointer().StringifyUriFragment(sb); + ValueType schemaRef(sb.GetString(), static_cast(sb.GetSize() / sizeof(Ch)), + GetStateAllocator()); + result.AddMember(GetSchemaRefString(), schemaRef, GetStateAllocator()); + } + + void AddErrorCode(ValueType& result, const ValidateErrorCode code) { + result.AddMember(GetErrorCodeString(), code, GetStateAllocator()); + } + + void AddError(ValueType& keyword, ValueType& error) { + typename ValueType::MemberIterator member = error_.FindMember(keyword); + if (member == error_.MemberEnd()) + error_.AddMember(keyword, error, GetStateAllocator()); + else { + if (member->value.IsObject()) { + ValueType errors(kArrayType); + errors.PushBack(member->value, GetStateAllocator()); + member->value = errors; + } + member->value.PushBack(error, GetStateAllocator()); + } + } + + void AddCurrentError(const ValidateErrorCode code, bool parent = false) { + AddErrorCode(currentError_, code); + AddErrorInstanceLocation(currentError_, parent); + AddErrorSchemaLocation(currentError_); + AddError(ValueType(SchemaType::GetValidateErrorKeyword(code), GetStateAllocator(), false).Move(), currentError_); + } + + void MergeError(ValueType& other) { + for (typename ValueType::MemberIterator it = other.MemberBegin(), end = other.MemberEnd(); it != end; ++it) { + AddError(it->name, it->value); + } + } + + void AddNumberError(const ValidateErrorCode code, ValueType& actual, const SValue& expected, + const typename SchemaType::ValueType& (*exclusive)() = 0) { + currentError_.SetObject(); + currentError_.AddMember(GetActualString(), actual, GetStateAllocator()); + currentError_.AddMember(GetExpectedString(), ValueType(expected, GetStateAllocator()).Move(), GetStateAllocator()); + if (exclusive) + currentError_.AddMember(ValueType(exclusive(), GetStateAllocator()).Move(), true, GetStateAllocator()); + AddCurrentError(code); + } + + void AddErrorArray(const ValidateErrorCode code, + ISchemaValidator** subvalidators, SizeType count) { + ValueType errors(kArrayType); + for (SizeType i = 0; i < count; ++i) + errors.PushBack(static_cast(subvalidators[i])->GetError(), GetStateAllocator()); + currentError_.SetObject(); + currentError_.AddMember(GetErrorsString(), errors, GetStateAllocator()); + AddCurrentError(code); + } + + const SchemaType& CurrentSchema() const { return *schemaStack_.template Top()->schema; } + Context& CurrentContext() { return *schemaStack_.template Top(); } + const Context& CurrentContext() const { return *schemaStack_.template Top(); } + + static const size_t kDefaultSchemaStackCapacity = 1024; + static const size_t kDefaultDocumentStackCapacity = 256; + const SchemaDocumentType* schemaDocument_; + const SchemaType& root_; + StateAllocator* stateAllocator_; + StateAllocator* ownStateAllocator_; + internal::Stack schemaStack_; //!< stack to store the current path of schema (BaseSchemaType *) + internal::Stack documentStack_; //!< stack to store the current path of validating document (Ch) + OutputHandler* outputHandler_; + ValueType error_; + ValueType currentError_; + ValueType missingDependents_; + bool valid_; + unsigned flags_; + unsigned depth_; +}; + +typedef GenericSchemaValidator SchemaValidator; + +/////////////////////////////////////////////////////////////////////////////// +// SchemaValidatingReader + +//! A helper class for parsing with validation. +/*! + This helper class is a functor, designed as a parameter of \ref GenericDocument::Populate(). + + \tparam parseFlags Combination of \ref ParseFlag. + \tparam InputStream Type of input stream, implementing Stream concept. + \tparam SourceEncoding Encoding of the input stream. + \tparam SchemaDocumentType Type of schema document. + \tparam StackAllocator Allocator type for stack. +*/ +template < + unsigned parseFlags, + typename InputStream, + typename SourceEncoding, + typename SchemaDocumentType = SchemaDocument, + typename StackAllocator = CrtAllocator> +class SchemaValidatingReader { +public: + typedef typename SchemaDocumentType::PointerType PointerType; + typedef typename InputStream::Ch Ch; + typedef GenericValue ValueType; + + //! Constructor + /*! + \param is Input stream. + \param sd Schema document. + */ + SchemaValidatingReader(InputStream& is, const SchemaDocumentType& sd) : is_(is), sd_(sd), invalidSchemaKeyword_(), invalidSchemaCode_(kValidateErrorNone), error_(kObjectType), isValid_(true) {} + + template + bool operator()(Handler& handler) { + GenericReader reader; + GenericSchemaValidator validator(sd_, handler); + parseResult_ = reader.template Parse(is_, validator); + + isValid_ = validator.IsValid(); + if (isValid_) { + invalidSchemaPointer_ = PointerType(); + invalidSchemaKeyword_ = 0; + invalidDocumentPointer_ = PointerType(); + error_.SetObject(); + } + else { + invalidSchemaPointer_ = validator.GetInvalidSchemaPointer(); + invalidSchemaKeyword_ = validator.GetInvalidSchemaKeyword(); + invalidSchemaCode_ = validator.GetInvalidSchemaCode(); + invalidDocumentPointer_ = validator.GetInvalidDocumentPointer(); + error_.CopyFrom(validator.GetError(), allocator_); + } + + return parseResult_; + } + + const ParseResult& GetParseResult() const { return parseResult_; } + bool IsValid() const { return isValid_; } + const PointerType& GetInvalidSchemaPointer() const { return invalidSchemaPointer_; } + const Ch* GetInvalidSchemaKeyword() const { return invalidSchemaKeyword_; } + const PointerType& GetInvalidDocumentPointer() const { return invalidDocumentPointer_; } + const ValueType& GetError() const { return error_; } + ValidateErrorCode GetInvalidSchemaCode() const { return invalidSchemaCode_; } + +private: + InputStream& is_; + const SchemaDocumentType& sd_; + + ParseResult parseResult_; + PointerType invalidSchemaPointer_; + const Ch* invalidSchemaKeyword_; + PointerType invalidDocumentPointer_; + ValidateErrorCode invalidSchemaCode_; + StackAllocator allocator_; + ValueType error_; + bool isValid_; +}; + +RAPIDJSON_NAMESPACE_END +RAPIDJSON_DIAG_POP + +#endif // RAPIDJSON_SCHEMA_H_ diff --git a/include/rapidjson/stream.h b/include/rapidjson/stream.h new file mode 100644 index 0000000000..1fd70915c5 --- /dev/null +++ b/include/rapidjson/stream.h @@ -0,0 +1,223 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "rapidjson.h" + +#ifndef RAPIDJSON_STREAM_H_ +#define RAPIDJSON_STREAM_H_ + +#include "encodings.h" + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// Stream + +/*! \class rapidjson::Stream + \brief Concept for reading and writing characters. + + For read-only stream, no need to implement PutBegin(), Put(), Flush() and PutEnd(). + + For write-only stream, only need to implement Put() and Flush(). + +\code +concept Stream { + typename Ch; //!< Character type of the stream. + + //! Read the current character from stream without moving the read cursor. + Ch Peek() const; + + //! Read the current character from stream and moving the read cursor to next character. + Ch Take(); + + //! Get the current read cursor. + //! \return Number of characters read from start. + size_t Tell(); + + //! Begin writing operation at the current read pointer. + //! \return The begin writer pointer. + Ch* PutBegin(); + + //! Write a character. + void Put(Ch c); + + //! Flush the buffer. + void Flush(); + + //! End the writing operation. + //! \param begin The begin write pointer returned by PutBegin(). + //! \return Number of characters written. + size_t PutEnd(Ch* begin); +} +\endcode +*/ + +//! Provides additional information for stream. +/*! + By using traits pattern, this type provides a default configuration for stream. + For custom stream, this type can be specialized for other configuration. + See TEST(Reader, CustomStringStream) in readertest.cpp for example. +*/ +template +struct StreamTraits { + //! Whether to make local copy of stream for optimization during parsing. + /*! + By default, for safety, streams do not use local copy optimization. + Stream that can be copied fast should specialize this, like StreamTraits. + */ + enum { copyOptimization = 0 }; +}; + +//! Reserve n characters for writing to a stream. +template +inline void PutReserve(Stream& stream, size_t count) { + (void)stream; + (void)count; +} + +//! Write character to a stream, presuming buffer is reserved. +template +inline void PutUnsafe(Stream& stream, typename Stream::Ch c) { + stream.Put(c); +} + +//! Put N copies of a character to a stream. +template +inline void PutN(Stream& stream, Ch c, size_t n) { + PutReserve(stream, n); + for (size_t i = 0; i < n; i++) + PutUnsafe(stream, c); +} + +/////////////////////////////////////////////////////////////////////////////// +// GenericStreamWrapper + +//! A Stream Wrapper +/*! \tThis string stream is a wrapper for any stream by just forwarding any + \treceived message to the origin stream. + \note implements Stream concept +*/ + +#if defined(_MSC_VER) && _MSC_VER <= 1800 +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4702) // unreachable code +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +template > +class GenericStreamWrapper { +public: + typedef typename Encoding::Ch Ch; + GenericStreamWrapper(InputStream& is): is_(is) {} + + Ch Peek() const { return is_.Peek(); } + Ch Take() { return is_.Take(); } + size_t Tell() { return is_.Tell(); } + Ch* PutBegin() { return is_.PutBegin(); } + void Put(Ch ch) { is_.Put(ch); } + void Flush() { is_.Flush(); } + size_t PutEnd(Ch* ch) { return is_.PutEnd(ch); } + + // wrapper for MemoryStream + const Ch* Peek4() const { return is_.Peek4(); } + + // wrapper for AutoUTFInputStream + UTFType GetType() const { return is_.GetType(); } + bool HasBOM() const { return is_.HasBOM(); } + +protected: + InputStream& is_; +}; + +#if defined(_MSC_VER) && _MSC_VER <= 1800 +RAPIDJSON_DIAG_POP +#endif + +/////////////////////////////////////////////////////////////////////////////// +// StringStream + +//! Read-only string stream. +/*! \note implements Stream concept +*/ +template +struct GenericStringStream { + typedef typename Encoding::Ch Ch; + + GenericStringStream(const Ch *src) : src_(src), head_(src) {} + + Ch Peek() const { return *src_; } + Ch Take() { return *src_++; } + size_t Tell() const { return static_cast(src_ - head_); } + + Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; } + void Put(Ch) { RAPIDJSON_ASSERT(false); } + void Flush() { RAPIDJSON_ASSERT(false); } + size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; } + + const Ch* src_; //!< Current read position. + const Ch* head_; //!< Original head of the string. +}; + +template +struct StreamTraits > { + enum { copyOptimization = 1 }; +}; + +//! String stream with UTF8 encoding. +typedef GenericStringStream > StringStream; + +/////////////////////////////////////////////////////////////////////////////// +// InsituStringStream + +//! A read-write string stream. +/*! This string stream is particularly designed for in-situ parsing. + \note implements Stream concept +*/ +template +struct GenericInsituStringStream { + typedef typename Encoding::Ch Ch; + + GenericInsituStringStream(Ch *src) : src_(src), dst_(0), head_(src) {} + + // Read + Ch Peek() { return *src_; } + Ch Take() { return *src_++; } + size_t Tell() { return static_cast(src_ - head_); } + + // Write + void Put(Ch c) { RAPIDJSON_ASSERT(dst_ != 0); *dst_++ = c; } + + Ch* PutBegin() { return dst_ = src_; } + size_t PutEnd(Ch* begin) { return static_cast(dst_ - begin); } + void Flush() {} + + Ch* Push(size_t count) { Ch* begin = dst_; dst_ += count; return begin; } + void Pop(size_t count) { dst_ -= count; } + + Ch* src_; + Ch* dst_; + Ch* head_; +}; + +template +struct StreamTraits > { + enum { copyOptimization = 1 }; +}; + +//! Insitu string stream with UTF8 encoding. +typedef GenericInsituStringStream > InsituStringStream; + +RAPIDJSON_NAMESPACE_END + +#endif // RAPIDJSON_STREAM_H_ diff --git a/include/rapidjson/stringbuffer.h b/include/rapidjson/stringbuffer.h new file mode 100644 index 0000000000..82ad3ca6bb --- /dev/null +++ b/include/rapidjson/stringbuffer.h @@ -0,0 +1,121 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_STRINGBUFFER_H_ +#define RAPIDJSON_STRINGBUFFER_H_ + +#include "stream.h" +#include "internal/stack.h" + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS +#include // std::move +#endif + +#include "internal/stack.h" + +#if defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(c++98-compat) +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +//! Represents an in-memory output stream. +/*! + \tparam Encoding Encoding of the stream. + \tparam Allocator type for allocating memory buffer. + \note implements Stream concept +*/ +template +class GenericStringBuffer { +public: + typedef typename Encoding::Ch Ch; + + GenericStringBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {} + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + GenericStringBuffer(GenericStringBuffer&& rhs) : stack_(std::move(rhs.stack_)) {} + GenericStringBuffer& operator=(GenericStringBuffer&& rhs) { + if (&rhs != this) + stack_ = std::move(rhs.stack_); + return *this; + } +#endif + + void Put(Ch c) { *stack_.template Push() = c; } + void PutUnsafe(Ch c) { *stack_.template PushUnsafe() = c; } + void Flush() {} + + void Clear() { stack_.Clear(); } + void ShrinkToFit() { + // Push and pop a null terminator. This is safe. + *stack_.template Push() = '\0'; + stack_.ShrinkToFit(); + stack_.template Pop(1); + } + + void Reserve(size_t count) { stack_.template Reserve(count); } + Ch* Push(size_t count) { return stack_.template Push(count); } + Ch* PushUnsafe(size_t count) { return stack_.template PushUnsafe(count); } + void Pop(size_t count) { stack_.template Pop(count); } + + const Ch* GetString() const { + // Push and pop a null terminator. This is safe. + *stack_.template Push() = '\0'; + stack_.template Pop(1); + + return stack_.template Bottom(); + } + + //! Get the size of string in bytes in the string buffer. + size_t GetSize() const { return stack_.GetSize(); } + + //! Get the length of string in Ch in the string buffer. + size_t GetLength() const { return stack_.GetSize() / sizeof(Ch); } + + static const size_t kDefaultCapacity = 256; + mutable internal::Stack stack_; + +private: + // Prohibit copy constructor & assignment operator. + GenericStringBuffer(const GenericStringBuffer&); + GenericStringBuffer& operator=(const GenericStringBuffer&); +}; + +//! String buffer with UTF8 encoding +typedef GenericStringBuffer > StringBuffer; + +template +inline void PutReserve(GenericStringBuffer& stream, size_t count) { + stream.Reserve(count); +} + +template +inline void PutUnsafe(GenericStringBuffer& stream, typename Encoding::Ch c) { + stream.PutUnsafe(c); +} + +//! Implement specialized version of PutN() with memset() for better performance. +template<> +inline void PutN(GenericStringBuffer >& stream, char c, size_t n) { + std::memset(stream.stack_.Push(n), c, n * sizeof(c)); +} + +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_STRINGBUFFER_H_ diff --git a/include/rapidjson/uri.h b/include/rapidjson/uri.h new file mode 100644 index 0000000000..f93e508a4f --- /dev/null +++ b/include/rapidjson/uri.h @@ -0,0 +1,481 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// (C) Copyright IBM Corporation 2021 +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_URI_H_ +#define RAPIDJSON_URI_H_ + +#include "internal/strfunc.h" + +#if defined(__clang__) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(c++98-compat) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// GenericUri + +template +class GenericUri { +public: + typedef typename ValueType::Ch Ch; +#if RAPIDJSON_HAS_STDSTRING + typedef std::basic_string String; +#endif + + //! Constructors + GenericUri(Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + } + + GenericUri(const Ch* uri, SizeType len, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + Parse(uri, len); + } + + GenericUri(const Ch* uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + Parse(uri, internal::StrLen(uri)); + } + + // Use with specializations of GenericValue + template GenericUri(const T& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + const Ch* u = uri.template Get(); // TypeHelper from document.h + Parse(u, internal::StrLen(u)); + } + +#if RAPIDJSON_HAS_STDSTRING + GenericUri(const String& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + Parse(uri.c_str(), internal::StrLen(uri.c_str())); + } +#endif + + //! Copy constructor + GenericUri(const GenericUri& rhs) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(), ownAllocator_() { + *this = rhs; + } + + //! Copy constructor + GenericUri(const GenericUri& rhs, Allocator* allocator) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() { + *this = rhs; + } + + //! Destructor. + ~GenericUri() { + Free(); + RAPIDJSON_DELETE(ownAllocator_); + } + + //! Assignment operator + GenericUri& operator=(const GenericUri& rhs) { + if (this != &rhs) { + // Do not delete ownAllocator + Free(); + Allocate(rhs.GetStringLength()); + auth_ = CopyPart(scheme_, rhs.scheme_, rhs.GetSchemeStringLength()); + path_ = CopyPart(auth_, rhs.auth_, rhs.GetAuthStringLength()); + query_ = CopyPart(path_, rhs.path_, rhs.GetPathStringLength()); + frag_ = CopyPart(query_, rhs.query_, rhs.GetQueryStringLength()); + base_ = CopyPart(frag_, rhs.frag_, rhs.GetFragStringLength()); + uri_ = CopyPart(base_, rhs.base_, rhs.GetBaseStringLength()); + CopyPart(uri_, rhs.uri_, rhs.GetStringLength()); + } + return *this; + } + + //! Getters + // Use with specializations of GenericValue + template void Get(T& uri, Allocator& allocator) { + uri.template Set(this->GetString(), allocator); // TypeHelper from document.h + } + + const Ch* GetString() const { return uri_; } + SizeType GetStringLength() const { return uri_ == 0 ? 0 : internal::StrLen(uri_); } + const Ch* GetBaseString() const { return base_; } + SizeType GetBaseStringLength() const { return base_ == 0 ? 0 : internal::StrLen(base_); } + const Ch* GetSchemeString() const { return scheme_; } + SizeType GetSchemeStringLength() const { return scheme_ == 0 ? 0 : internal::StrLen(scheme_); } + const Ch* GetAuthString() const { return auth_; } + SizeType GetAuthStringLength() const { return auth_ == 0 ? 0 : internal::StrLen(auth_); } + const Ch* GetPathString() const { return path_; } + SizeType GetPathStringLength() const { return path_ == 0 ? 0 : internal::StrLen(path_); } + const Ch* GetQueryString() const { return query_; } + SizeType GetQueryStringLength() const { return query_ == 0 ? 0 : internal::StrLen(query_); } + const Ch* GetFragString() const { return frag_; } + SizeType GetFragStringLength() const { return frag_ == 0 ? 0 : internal::StrLen(frag_); } + +#if RAPIDJSON_HAS_STDSTRING + static String Get(const GenericUri& uri) { return String(uri.GetString(), uri.GetStringLength()); } + static String GetBase(const GenericUri& uri) { return String(uri.GetBaseString(), uri.GetBaseStringLength()); } + static String GetScheme(const GenericUri& uri) { return String(uri.GetSchemeString(), uri.GetSchemeStringLength()); } + static String GetAuth(const GenericUri& uri) { return String(uri.GetAuthString(), uri.GetAuthStringLength()); } + static String GetPath(const GenericUri& uri) { return String(uri.GetPathString(), uri.GetPathStringLength()); } + static String GetQuery(const GenericUri& uri) { return String(uri.GetQueryString(), uri.GetQueryStringLength()); } + static String GetFrag(const GenericUri& uri) { return String(uri.GetFragString(), uri.GetFragStringLength()); } +#endif + + //! Equality operators + bool operator==(const GenericUri& rhs) const { + return Match(rhs, true); + } + + bool operator!=(const GenericUri& rhs) const { + return !Match(rhs, true); + } + + bool Match(const GenericUri& uri, bool full = true) const { + Ch* s1; + Ch* s2; + if (full) { + s1 = uri_; + s2 = uri.uri_; + } else { + s1 = base_; + s2 = uri.base_; + } + if (s1 == s2) return true; + if (s1 == 0 || s2 == 0) return false; + return internal::StrCmp(s1, s2) == 0; + } + + //! Resolve this URI against another (base) URI in accordance with URI resolution rules. + // See https://tools.ietf.org/html/rfc3986 + // Use for resolving an id or $ref with an in-scope id. + // Returns a new GenericUri for the resolved URI. + GenericUri Resolve(const GenericUri& baseuri, Allocator* allocator = 0) { + GenericUri resuri; + resuri.allocator_ = allocator; + // Ensure enough space for combining paths + resuri.Allocate(GetStringLength() + baseuri.GetStringLength() + 1); // + 1 for joining slash + + if (!(GetSchemeStringLength() == 0)) { + // Use all of this URI + resuri.auth_ = CopyPart(resuri.scheme_, scheme_, GetSchemeStringLength()); + resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength()); + resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength()); + resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength()); + resuri.RemoveDotSegments(); + } else { + // Use the base scheme + resuri.auth_ = CopyPart(resuri.scheme_, baseuri.scheme_, baseuri.GetSchemeStringLength()); + if (!(GetAuthStringLength() == 0)) { + // Use this auth, path, query + resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength()); + resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength()); + resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength()); + resuri.RemoveDotSegments(); + } else { + // Use the base auth + resuri.path_ = CopyPart(resuri.auth_, baseuri.auth_, baseuri.GetAuthStringLength()); + if (GetPathStringLength() == 0) { + // Use the base path + resuri.query_ = CopyPart(resuri.path_, baseuri.path_, baseuri.GetPathStringLength()); + if (GetQueryStringLength() == 0) { + // Use the base query + resuri.frag_ = CopyPart(resuri.query_, baseuri.query_, baseuri.GetQueryStringLength()); + } else { + // Use this query + resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength()); + } + } else { + if (path_[0] == '/') { + // Absolute path - use all of this path + resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength()); + resuri.RemoveDotSegments(); + } else { + // Relative path - append this path to base path after base path's last slash + size_t pos = 0; + if (!(baseuri.GetAuthStringLength() == 0) && baseuri.GetPathStringLength() == 0) { + resuri.path_[pos] = '/'; + pos++; + } + size_t lastslashpos = baseuri.GetPathStringLength(); + while (lastslashpos > 0) { + if (baseuri.path_[lastslashpos - 1] == '/') break; + lastslashpos--; + } + std::memcpy(&resuri.path_[pos], baseuri.path_, lastslashpos * sizeof(Ch)); + pos += lastslashpos; + resuri.query_ = CopyPart(&resuri.path_[pos], path_, GetPathStringLength()); + resuri.RemoveDotSegments(); + } + // Use this query + resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength()); + } + } + } + // Always use this frag + resuri.base_ = CopyPart(resuri.frag_, frag_, GetFragStringLength()); + + // Re-constitute base_ and uri_ + resuri.SetBase(); + resuri.uri_ = resuri.base_ + resuri.GetBaseStringLength() + 1; + resuri.SetUri(); + return resuri; + } + + //! Get the allocator of this GenericUri. + Allocator& GetAllocator() { return *allocator_; } + +private: + // Allocate memory for a URI + // Returns total amount allocated + std::size_t Allocate(std::size_t len) { + // Create own allocator if user did not supply. + if (!allocator_) + ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)(); + + // Allocate one block containing each part of the URI (5) plus base plus full URI, all null terminated. + // Order: scheme, auth, path, query, frag, base, uri + // Note need to set, increment, assign in 3 stages to avoid compiler warning bug. + size_t total = (3 * len + 7) * sizeof(Ch); + scheme_ = static_cast(allocator_->Malloc(total)); + *scheme_ = '\0'; + auth_ = scheme_; + auth_++; + *auth_ = '\0'; + path_ = auth_; + path_++; + *path_ = '\0'; + query_ = path_; + query_++; + *query_ = '\0'; + frag_ = query_; + frag_++; + *frag_ = '\0'; + base_ = frag_; + base_++; + *base_ = '\0'; + uri_ = base_; + uri_++; + *uri_ = '\0'; + return total; + } + + // Free memory for a URI + void Free() { + if (scheme_) { + Allocator::Free(scheme_); + scheme_ = 0; + } + } + + // Parse a URI into constituent scheme, authority, path, query, & fragment parts + // Supports URIs that match regex ^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))? as per + // https://tools.ietf.org/html/rfc3986 + void Parse(const Ch* uri, std::size_t len) { + std::size_t start = 0, pos1 = 0, pos2 = 0; + Allocate(len); + + // Look for scheme ([^:/?#]+):)? + if (start < len) { + while (pos1 < len) { + if (uri[pos1] == ':') break; + pos1++; + } + if (pos1 != len) { + while (pos2 < len) { + if (uri[pos2] == '/') break; + if (uri[pos2] == '?') break; + if (uri[pos2] == '#') break; + pos2++; + } + if (pos1 < pos2) { + pos1++; + std::memcpy(scheme_, &uri[start], pos1 * sizeof(Ch)); + scheme_[pos1] = '\0'; + start = pos1; + } + } + } + // Look for auth (//([^/?#]*))? + // Note need to set, increment, assign in 3 stages to avoid compiler warning bug. + auth_ = scheme_ + GetSchemeStringLength(); + auth_++; + *auth_ = '\0'; + if (start < len - 1 && uri[start] == '/' && uri[start + 1] == '/') { + pos2 = start + 2; + while (pos2 < len) { + if (uri[pos2] == '/') break; + if (uri[pos2] == '?') break; + if (uri[pos2] == '#') break; + pos2++; + } + std::memcpy(auth_, &uri[start], (pos2 - start) * sizeof(Ch)); + auth_[pos2 - start] = '\0'; + start = pos2; + } + // Look for path ([^?#]*) + // Note need to set, increment, assign in 3 stages to avoid compiler warning bug. + path_ = auth_ + GetAuthStringLength(); + path_++; + *path_ = '\0'; + if (start < len) { + pos2 = start; + while (pos2 < len) { + if (uri[pos2] == '?') break; + if (uri[pos2] == '#') break; + pos2++; + } + if (start != pos2) { + std::memcpy(path_, &uri[start], (pos2 - start) * sizeof(Ch)); + path_[pos2 - start] = '\0'; + if (path_[0] == '/') + RemoveDotSegments(); // absolute path - normalize + start = pos2; + } + } + // Look for query (\?([^#]*))? + // Note need to set, increment, assign in 3 stages to avoid compiler warning bug. + query_ = path_ + GetPathStringLength(); + query_++; + *query_ = '\0'; + if (start < len && uri[start] == '?') { + pos2 = start + 1; + while (pos2 < len) { + if (uri[pos2] == '#') break; + pos2++; + } + if (start != pos2) { + std::memcpy(query_, &uri[start], (pos2 - start) * sizeof(Ch)); + query_[pos2 - start] = '\0'; + start = pos2; + } + } + // Look for fragment (#(.*))? + // Note need to set, increment, assign in 3 stages to avoid compiler warning bug. + frag_ = query_ + GetQueryStringLength(); + frag_++; + *frag_ = '\0'; + if (start < len && uri[start] == '#') { + std::memcpy(frag_, &uri[start], (len - start) * sizeof(Ch)); + frag_[len - start] = '\0'; + } + + // Re-constitute base_ and uri_ + base_ = frag_ + GetFragStringLength() + 1; + SetBase(); + uri_ = base_ + GetBaseStringLength() + 1; + SetUri(); + } + + // Reconstitute base + void SetBase() { + Ch* next = base_; + std::memcpy(next, scheme_, GetSchemeStringLength() * sizeof(Ch)); + next+= GetSchemeStringLength(); + std::memcpy(next, auth_, GetAuthStringLength() * sizeof(Ch)); + next+= GetAuthStringLength(); + std::memcpy(next, path_, GetPathStringLength() * sizeof(Ch)); + next+= GetPathStringLength(); + std::memcpy(next, query_, GetQueryStringLength() * sizeof(Ch)); + next+= GetQueryStringLength(); + *next = '\0'; + } + + // Reconstitute uri + void SetUri() { + Ch* next = uri_; + std::memcpy(next, base_, GetBaseStringLength() * sizeof(Ch)); + next+= GetBaseStringLength(); + std::memcpy(next, frag_, GetFragStringLength() * sizeof(Ch)); + next+= GetFragStringLength(); + *next = '\0'; + } + + // Copy a part from one GenericUri to another + // Return the pointer to the next part to be copied to + Ch* CopyPart(Ch* to, Ch* from, std::size_t len) { + RAPIDJSON_ASSERT(to != 0); + RAPIDJSON_ASSERT(from != 0); + std::memcpy(to, from, len * sizeof(Ch)); + to[len] = '\0'; + Ch* next = to + len + 1; + return next; + } + + // Remove . and .. segments from the path_ member. + // https://tools.ietf.org/html/rfc3986 + // This is done in place as we are only removing segments. + void RemoveDotSegments() { + std::size_t pathlen = GetPathStringLength(); + std::size_t pathpos = 0; // Position in path_ + std::size_t newpos = 0; // Position in new path_ + + // Loop through each segment in original path_ + while (pathpos < pathlen) { + // Get next segment, bounded by '/' or end + size_t slashpos = 0; + while ((pathpos + slashpos) < pathlen) { + if (path_[pathpos + slashpos] == '/') break; + slashpos++; + } + // Check for .. and . segments + if (slashpos == 2 && path_[pathpos] == '.' && path_[pathpos + 1] == '.') { + // Backup a .. segment in the new path_ + // We expect to find a previously added slash at the end or nothing + RAPIDJSON_ASSERT(newpos == 0 || path_[newpos - 1] == '/'); + size_t lastslashpos = newpos; + // Make sure we don't go beyond the start segment + if (lastslashpos > 1) { + // Find the next to last slash and back up to it + lastslashpos--; + while (lastslashpos > 0) { + if (path_[lastslashpos - 1] == '/') break; + lastslashpos--; + } + // Set the new path_ position + newpos = lastslashpos; + } + } else if (slashpos == 1 && path_[pathpos] == '.') { + // Discard . segment, leaves new path_ unchanged + } else { + // Move any other kind of segment to the new path_ + RAPIDJSON_ASSERT(newpos <= pathpos); + std::memmove(&path_[newpos], &path_[pathpos], slashpos * sizeof(Ch)); + newpos += slashpos; + // Add slash if not at end + if ((pathpos + slashpos) < pathlen) { + path_[newpos] = '/'; + newpos++; + } + } + // Move to next segment + pathpos += slashpos + 1; + } + path_[newpos] = '\0'; + } + + Ch* uri_; // Everything + Ch* base_; // Everything except fragment + Ch* scheme_; // Includes the : + Ch* auth_; // Includes the // + Ch* path_; // Absolute if starts with / + Ch* query_; // Includes the ? + Ch* frag_; // Includes the # + + Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to ownAllocator_. + Allocator* ownAllocator_; //!< Allocator owned by this Uri. +}; + +//! GenericUri for Value (UTF-8, default allocator). +typedef GenericUri Uri; + +RAPIDJSON_NAMESPACE_END + +#if defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_URI_H_ diff --git a/include/rapidjson/writer.h b/include/rapidjson/writer.h new file mode 100644 index 0000000000..632e02ce74 --- /dev/null +++ b/include/rapidjson/writer.h @@ -0,0 +1,721 @@ +// Tencent is pleased to support the open source community by making RapidJSON available. +// +// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. +// +// Licensed under the MIT License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://opensource.org/licenses/MIT +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef RAPIDJSON_WRITER_H_ +#define RAPIDJSON_WRITER_H_ + +#include "stream.h" +#include "internal/clzll.h" +#include "internal/meta.h" +#include "internal/stack.h" +#include "internal/strfunc.h" +#include "internal/dtoa.h" +#include "internal/itoa.h" +#include "stringbuffer.h" +#include // placement new + +#if defined(RAPIDJSON_SIMD) && defined(_MSC_VER) +#include +#pragma intrinsic(_BitScanForward) +#endif +#ifdef RAPIDJSON_SSE42 +#include +#elif defined(RAPIDJSON_SSE2) +#include +#elif defined(RAPIDJSON_NEON) +#include +#endif + +#ifdef __clang__ +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(padded) +RAPIDJSON_DIAG_OFF(unreachable-code) +RAPIDJSON_DIAG_OFF(c++98-compat) +#elif defined(_MSC_VER) +RAPIDJSON_DIAG_PUSH +RAPIDJSON_DIAG_OFF(4127) // conditional expression is constant +#endif + +RAPIDJSON_NAMESPACE_BEGIN + +/////////////////////////////////////////////////////////////////////////////// +// WriteFlag + +/*! \def RAPIDJSON_WRITE_DEFAULT_FLAGS + \ingroup RAPIDJSON_CONFIG + \brief User-defined kWriteDefaultFlags definition. + + User can define this as any \c WriteFlag combinations. +*/ +#ifndef RAPIDJSON_WRITE_DEFAULT_FLAGS +#define RAPIDJSON_WRITE_DEFAULT_FLAGS kWriteNoFlags +#endif + +//! Combination of writeFlags +enum WriteFlag { + kWriteNoFlags = 0, //!< No flags are set. + kWriteValidateEncodingFlag = 1, //!< Validate encoding of JSON strings. + kWriteNanAndInfFlag = 2, //!< Allow writing of Infinity, -Infinity and NaN. + kWriteNanAndInfNullFlag = 4, //!< Allow writing of Infinity, -Infinity and NaN as null. + kWriteDefaultFlags = RAPIDJSON_WRITE_DEFAULT_FLAGS //!< Default write flags. Can be customized by defining RAPIDJSON_WRITE_DEFAULT_FLAGS +}; + +//! JSON writer +/*! Writer implements the concept Handler. + It generates JSON text by events to an output os. + + User may programmatically calls the functions of a writer to generate JSON text. + + On the other side, a writer can also be passed to objects that generates events, + + for example Reader::Parse() and Document::Accept(). + + \tparam OutputStream Type of output stream. + \tparam SourceEncoding Encoding of source string. + \tparam TargetEncoding Encoding of output stream. + \tparam StackAllocator Type of allocator for allocating memory of stack. + \note implements Handler concept +*/ +template, typename TargetEncoding = UTF8<>, typename StackAllocator = CrtAllocator, unsigned writeFlags = kWriteDefaultFlags> +class Writer { +public: + typedef typename SourceEncoding::Ch Ch; + + static const int kDefaultMaxDecimalPlaces = 324; + + //! Constructor + /*! \param os Output stream. + \param stackAllocator User supplied allocator. If it is null, it will create a private one. + \param levelDepth Initial capacity of stack. + */ + explicit + Writer(OutputStream& os, StackAllocator* stackAllocator = 0, size_t levelDepth = kDefaultLevelDepth) : + os_(&os), level_stack_(stackAllocator, levelDepth * sizeof(Level)), maxDecimalPlaces_(kDefaultMaxDecimalPlaces), hasRoot_(false) {} + + explicit + Writer(StackAllocator* allocator = 0, size_t levelDepth = kDefaultLevelDepth) : + os_(0), level_stack_(allocator, levelDepth * sizeof(Level)), maxDecimalPlaces_(kDefaultMaxDecimalPlaces), hasRoot_(false) {} + +#if RAPIDJSON_HAS_CXX11_RVALUE_REFS + Writer(Writer&& rhs) : + os_(rhs.os_), level_stack_(std::move(rhs.level_stack_)), maxDecimalPlaces_(rhs.maxDecimalPlaces_), hasRoot_(rhs.hasRoot_) { + rhs.os_ = 0; + } +#endif + + //! Reset the writer with a new stream. + /*! + This function reset the writer with a new stream and default settings, + in order to make a Writer object reusable for output multiple JSONs. + + \param os New output stream. + \code + Writer writer(os1); + writer.StartObject(); + // ... + writer.EndObject(); + + writer.Reset(os2); + writer.StartObject(); + // ... + writer.EndObject(); + \endcode + */ + void Reset(OutputStream& os) { + os_ = &os; + hasRoot_ = false; + level_stack_.Clear(); + } + + //! Checks whether the output is a complete JSON. + /*! + A complete JSON has a complete root object or array. + */ + bool IsComplete() const { + return hasRoot_ && level_stack_.Empty(); + } + + int GetMaxDecimalPlaces() const { + return maxDecimalPlaces_; + } + + //! Sets the maximum number of decimal places for double output. + /*! + This setting truncates the output with specified number of decimal places. + + For example, + + \code + writer.SetMaxDecimalPlaces(3); + writer.StartArray(); + writer.Double(0.12345); // "0.123" + writer.Double(0.0001); // "0.0" + writer.Double(1.234567890123456e30); // "1.234567890123456e30" (do not truncate significand for positive exponent) + writer.Double(1.23e-4); // "0.0" (do truncate significand for negative exponent) + writer.EndArray(); + \endcode + + The default setting does not truncate any decimal places. You can restore to this setting by calling + \code + writer.SetMaxDecimalPlaces(Writer::kDefaultMaxDecimalPlaces); + \endcode + */ + void SetMaxDecimalPlaces(int maxDecimalPlaces) { + maxDecimalPlaces_ = maxDecimalPlaces; + } + + /*!@name Implementation of Handler + \see Handler + */ + //@{ + + bool Null() { Prefix(kNullType); return EndValue(WriteNull()); } + bool Bool(bool b) { Prefix(b ? kTrueType : kFalseType); return EndValue(WriteBool(b)); } + bool Int(int i) { Prefix(kNumberType); return EndValue(WriteInt(i)); } + bool Uint(unsigned u) { Prefix(kNumberType); return EndValue(WriteUint(u)); } + bool Int64(int64_t i64) { Prefix(kNumberType); return EndValue(WriteInt64(i64)); } + bool Uint64(uint64_t u64) { Prefix(kNumberType); return EndValue(WriteUint64(u64)); } + + //! Writes the given \c double value to the stream + /*! + \param d The value to be written. + \return Whether it is succeed. + */ + bool Double(double d) { Prefix(kNumberType); return EndValue(WriteDouble(d)); } + + bool RawNumber(const Ch* str, SizeType length, bool copy = false) { + RAPIDJSON_ASSERT(str != 0); + (void)copy; + Prefix(kNumberType); + return EndValue(WriteString(str, length)); + } + + bool String(const Ch* str, SizeType length, bool copy = false) { + RAPIDJSON_ASSERT(str != 0); + (void)copy; + Prefix(kStringType); + return EndValue(WriteString(str, length)); + } + +#if RAPIDJSON_HAS_STDSTRING + bool String(const std::basic_string& str) { + return String(str.data(), SizeType(str.size())); + } +#endif + + bool StartObject() { + Prefix(kObjectType); + new (level_stack_.template Push()) Level(false); + return WriteStartObject(); + } + + bool Key(const Ch* str, SizeType length, bool copy = false) { return String(str, length, copy); } + +#if RAPIDJSON_HAS_STDSTRING + bool Key(const std::basic_string& str) + { + return Key(str.data(), SizeType(str.size())); + } +#endif + + bool EndObject(SizeType memberCount = 0) { + (void)memberCount; + RAPIDJSON_ASSERT(level_stack_.GetSize() >= sizeof(Level)); // not inside an Object + RAPIDJSON_ASSERT(!level_stack_.template Top()->inArray); // currently inside an Array, not Object + RAPIDJSON_ASSERT(0 == level_stack_.template Top()->valueCount % 2); // Object has a Key without a Value + level_stack_.template Pop(1); + return EndValue(WriteEndObject()); + } + + bool StartArray() { + Prefix(kArrayType); + new (level_stack_.template Push()) Level(true); + return WriteStartArray(); + } + + bool EndArray(SizeType elementCount = 0) { + (void)elementCount; + RAPIDJSON_ASSERT(level_stack_.GetSize() >= sizeof(Level)); + RAPIDJSON_ASSERT(level_stack_.template Top()->inArray); + level_stack_.template Pop(1); + return EndValue(WriteEndArray()); + } + //@} + + /*! @name Convenience extensions */ + //@{ + + //! Simpler but slower overload. + bool String(const Ch* const& str) { return String(str, internal::StrLen(str)); } + bool Key(const Ch* const& str) { return Key(str, internal::StrLen(str)); } + + //@} + + //! Write a raw JSON value. + /*! + For user to write a stringified JSON as a value. + + \param json A well-formed JSON value. It should not contain null character within [0, length - 1] range. + \param length Length of the json. + \param type Type of the root of json. + */ + bool RawValue(const Ch* json, size_t length, Type type) { + RAPIDJSON_ASSERT(json != 0); + Prefix(type); + return EndValue(WriteRawValue(json, length)); + } + + //! Flush the output stream. + /*! + Allows the user to flush the output stream immediately. + */ + void Flush() { + os_->Flush(); + } + + static const size_t kDefaultLevelDepth = 32; + +protected: + //! Information for each nested level + struct Level { + Level(bool inArray_) : valueCount(0), inArray(inArray_) {} + size_t valueCount; //!< number of values in this level + bool inArray; //!< true if in array, otherwise in object + }; + + bool WriteNull() { + PutReserve(*os_, 4); + PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'u'); PutUnsafe(*os_, 'l'); PutUnsafe(*os_, 'l'); return true; + } + + bool WriteBool(bool b) { + if (b) { + PutReserve(*os_, 4); + PutUnsafe(*os_, 't'); PutUnsafe(*os_, 'r'); PutUnsafe(*os_, 'u'); PutUnsafe(*os_, 'e'); + } + else { + PutReserve(*os_, 5); + PutUnsafe(*os_, 'f'); PutUnsafe(*os_, 'a'); PutUnsafe(*os_, 'l'); PutUnsafe(*os_, 's'); PutUnsafe(*os_, 'e'); + } + return true; + } + + bool WriteInt(int i) { + char buffer[11]; + const char* end = internal::i32toa(i, buffer); + PutReserve(*os_, static_cast(end - buffer)); + for (const char* p = buffer; p != end; ++p) + PutUnsafe(*os_, static_cast(*p)); + return true; + } + + bool WriteUint(unsigned u) { + char buffer[10]; + const char* end = internal::u32toa(u, buffer); + PutReserve(*os_, static_cast(end - buffer)); + for (const char* p = buffer; p != end; ++p) + PutUnsafe(*os_, static_cast(*p)); + return true; + } + + bool WriteInt64(int64_t i64) { + char buffer[21]; + const char* end = internal::i64toa(i64, buffer); + PutReserve(*os_, static_cast(end - buffer)); + for (const char* p = buffer; p != end; ++p) + PutUnsafe(*os_, static_cast(*p)); + return true; + } + + bool WriteUint64(uint64_t u64) { + char buffer[20]; + char* end = internal::u64toa(u64, buffer); + PutReserve(*os_, static_cast(end - buffer)); + for (char* p = buffer; p != end; ++p) + PutUnsafe(*os_, static_cast(*p)); + return true; + } + + bool WriteDouble(double d) { + if (internal::Double(d).IsNanOrInf()) { + if (!(writeFlags & kWriteNanAndInfFlag) && !(writeFlags & kWriteNanAndInfNullFlag)) + return false; + if (writeFlags & kWriteNanAndInfNullFlag) { + PutReserve(*os_, 4); + PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'u'); PutUnsafe(*os_, 'l'); PutUnsafe(*os_, 'l'); + return true; + } + if (internal::Double(d).IsNan()) { + PutReserve(*os_, 3); + PutUnsafe(*os_, 'N'); PutUnsafe(*os_, 'a'); PutUnsafe(*os_, 'N'); + return true; + } + if (internal::Double(d).Sign()) { + PutReserve(*os_, 9); + PutUnsafe(*os_, '-'); + } + else + PutReserve(*os_, 8); + PutUnsafe(*os_, 'I'); PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'f'); + PutUnsafe(*os_, 'i'); PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'i'); PutUnsafe(*os_, 't'); PutUnsafe(*os_, 'y'); + return true; + } + + char buffer[25]; + char* end = internal::dtoa(d, buffer, maxDecimalPlaces_); + PutReserve(*os_, static_cast(end - buffer)); + for (char* p = buffer; p != end; ++p) + PutUnsafe(*os_, static_cast(*p)); + return true; + } + + bool WriteString(const Ch* str, SizeType length) { + static const typename OutputStream::Ch hexDigits[16] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' }; + static const char escape[256] = { +#define Z16 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 + //0 1 2 3 4 5 6 7 8 9 A B C D E F + 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'b', 't', 'n', 'u', 'f', 'r', 'u', 'u', // 00 + 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', 'u', // 10 + 0, 0, '"', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20 + Z16, Z16, // 30~4F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,'\\', 0, 0, 0, // 50 + Z16, Z16, Z16, Z16, Z16, Z16, Z16, Z16, Z16, Z16 // 60~FF +#undef Z16 + }; + + if (TargetEncoding::supportUnicode) + PutReserve(*os_, 2 + length * 6); // "\uxxxx..." + else + PutReserve(*os_, 2 + length * 12); // "\uxxxx\uyyyy..." + + PutUnsafe(*os_, '\"'); + GenericStringStream is(str); + while (ScanWriteUnescapedString(is, length)) { + const Ch c = is.Peek(); + if (!TargetEncoding::supportUnicode && static_cast(c) >= 0x80) { + // Unicode escaping + unsigned codepoint; + if (RAPIDJSON_UNLIKELY(!SourceEncoding::Decode(is, &codepoint))) + return false; + PutUnsafe(*os_, '\\'); + PutUnsafe(*os_, 'u'); + if (codepoint <= 0xD7FF || (codepoint >= 0xE000 && codepoint <= 0xFFFF)) { + PutUnsafe(*os_, hexDigits[(codepoint >> 12) & 15]); + PutUnsafe(*os_, hexDigits[(codepoint >> 8) & 15]); + PutUnsafe(*os_, hexDigits[(codepoint >> 4) & 15]); + PutUnsafe(*os_, hexDigits[(codepoint ) & 15]); + } + else { + RAPIDJSON_ASSERT(codepoint >= 0x010000 && codepoint <= 0x10FFFF); + // Surrogate pair + unsigned s = codepoint - 0x010000; + unsigned lead = (s >> 10) + 0xD800; + unsigned trail = (s & 0x3FF) + 0xDC00; + PutUnsafe(*os_, hexDigits[(lead >> 12) & 15]); + PutUnsafe(*os_, hexDigits[(lead >> 8) & 15]); + PutUnsafe(*os_, hexDigits[(lead >> 4) & 15]); + PutUnsafe(*os_, hexDigits[(lead ) & 15]); + PutUnsafe(*os_, '\\'); + PutUnsafe(*os_, 'u'); + PutUnsafe(*os_, hexDigits[(trail >> 12) & 15]); + PutUnsafe(*os_, hexDigits[(trail >> 8) & 15]); + PutUnsafe(*os_, hexDigits[(trail >> 4) & 15]); + PutUnsafe(*os_, hexDigits[(trail ) & 15]); + } + } + else if ((sizeof(Ch) == 1 || static_cast(c) < 256) && RAPIDJSON_UNLIKELY(escape[static_cast(c)])) { + is.Take(); + PutUnsafe(*os_, '\\'); + PutUnsafe(*os_, static_cast(escape[static_cast(c)])); + if (escape[static_cast(c)] == 'u') { + PutUnsafe(*os_, '0'); + PutUnsafe(*os_, '0'); + PutUnsafe(*os_, hexDigits[static_cast(c) >> 4]); + PutUnsafe(*os_, hexDigits[static_cast(c) & 0xF]); + } + } + else if (RAPIDJSON_UNLIKELY(!(writeFlags & kWriteValidateEncodingFlag ? + Transcoder::Validate(is, *os_) : + Transcoder::TranscodeUnsafe(is, *os_)))) + return false; + } + PutUnsafe(*os_, '\"'); + return true; + } + + bool ScanWriteUnescapedString(GenericStringStream& is, size_t length) { + return RAPIDJSON_LIKELY(is.Tell() < length); + } + + bool WriteStartObject() { os_->Put('{'); return true; } + bool WriteEndObject() { os_->Put('}'); return true; } + bool WriteStartArray() { os_->Put('['); return true; } + bool WriteEndArray() { os_->Put(']'); return true; } + + bool WriteRawValue(const Ch* json, size_t length) { + PutReserve(*os_, length); + GenericStringStream is(json); + while (RAPIDJSON_LIKELY(is.Tell() < length)) { + RAPIDJSON_ASSERT(is.Peek() != '\0'); + if (RAPIDJSON_UNLIKELY(!(writeFlags & kWriteValidateEncodingFlag ? + Transcoder::Validate(is, *os_) : + Transcoder::TranscodeUnsafe(is, *os_)))) + return false; + } + return true; + } + + void Prefix(Type type) { + (void)type; + if (RAPIDJSON_LIKELY(level_stack_.GetSize() != 0)) { // this value is not at root + Level* level = level_stack_.template Top(); + if (level->valueCount > 0) { + if (level->inArray) + os_->Put(','); // add comma if it is not the first element in array + else // in object + os_->Put((level->valueCount % 2 == 0) ? ',' : ':'); + } + if (!level->inArray && level->valueCount % 2 == 0) + RAPIDJSON_ASSERT(type == kStringType); // if it's in object, then even number should be a name + level->valueCount++; + } + else { + RAPIDJSON_ASSERT(!hasRoot_); // Should only has one and only one root. + hasRoot_ = true; + } + } + + // Flush the value if it is the top level one. + bool EndValue(bool ret) { + if (RAPIDJSON_UNLIKELY(level_stack_.Empty())) // end of json text + Flush(); + return ret; + } + + OutputStream* os_; + internal::Stack level_stack_; + int maxDecimalPlaces_; + bool hasRoot_; + +private: + // Prohibit copy constructor & assignment operator. + Writer(const Writer&); + Writer& operator=(const Writer&); +}; + +// Full specialization for StringStream to prevent memory copying + +template<> +inline bool Writer::WriteInt(int i) { + char *buffer = os_->Push(11); + const char* end = internal::i32toa(i, buffer); + os_->Pop(static_cast(11 - (end - buffer))); + return true; +} + +template<> +inline bool Writer::WriteUint(unsigned u) { + char *buffer = os_->Push(10); + const char* end = internal::u32toa(u, buffer); + os_->Pop(static_cast(10 - (end - buffer))); + return true; +} + +template<> +inline bool Writer::WriteInt64(int64_t i64) { + char *buffer = os_->Push(21); + const char* end = internal::i64toa(i64, buffer); + os_->Pop(static_cast(21 - (end - buffer))); + return true; +} + +template<> +inline bool Writer::WriteUint64(uint64_t u) { + char *buffer = os_->Push(20); + const char* end = internal::u64toa(u, buffer); + os_->Pop(static_cast(20 - (end - buffer))); + return true; +} + +template<> +inline bool Writer::WriteDouble(double d) { + if (internal::Double(d).IsNanOrInf()) { + // Note: This code path can only be reached if (RAPIDJSON_WRITE_DEFAULT_FLAGS & kWriteNanAndInfFlag). + if (!(kWriteDefaultFlags & kWriteNanAndInfFlag)) + return false; + if (kWriteDefaultFlags & kWriteNanAndInfNullFlag) { + PutReserve(*os_, 4); + PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'u'); PutUnsafe(*os_, 'l'); PutUnsafe(*os_, 'l'); + return true; + } + if (internal::Double(d).IsNan()) { + PutReserve(*os_, 3); + PutUnsafe(*os_, 'N'); PutUnsafe(*os_, 'a'); PutUnsafe(*os_, 'N'); + return true; + } + if (internal::Double(d).Sign()) { + PutReserve(*os_, 9); + PutUnsafe(*os_, '-'); + } + else + PutReserve(*os_, 8); + PutUnsafe(*os_, 'I'); PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'f'); + PutUnsafe(*os_, 'i'); PutUnsafe(*os_, 'n'); PutUnsafe(*os_, 'i'); PutUnsafe(*os_, 't'); PutUnsafe(*os_, 'y'); + return true; + } + + char *buffer = os_->Push(25); + char* end = internal::dtoa(d, buffer, maxDecimalPlaces_); + os_->Pop(static_cast(25 - (end - buffer))); + return true; +} + +#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) +template<> +inline bool Writer::ScanWriteUnescapedString(StringStream& is, size_t length) { + if (length < 16) + return RAPIDJSON_LIKELY(is.Tell() < length); + + if (!RAPIDJSON_LIKELY(is.Tell() < length)) + return false; + + const char* p = is.src_; + const char* end = is.head_ + length; + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + const char* endAligned = reinterpret_cast(reinterpret_cast(end) & static_cast(~15)); + if (nextAligned > end) + return true; + + while (p != nextAligned) + if (*p < 0x20 || *p == '\"' || *p == '\\') { + is.src_ = p; + return RAPIDJSON_LIKELY(is.Tell() < length); + } + else + os_->PutUnsafe(*p++); + + // The rest of string using SIMD + static const char dquote[16] = { '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"' }; + static const char bslash[16] = { '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\', '\\' }; + static const char space[16] = { 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F }; + const __m128i dq = _mm_loadu_si128(reinterpret_cast(&dquote[0])); + const __m128i bs = _mm_loadu_si128(reinterpret_cast(&bslash[0])); + const __m128i sp = _mm_loadu_si128(reinterpret_cast(&space[0])); + + for (; p != endAligned; p += 16) { + const __m128i s = _mm_load_si128(reinterpret_cast(p)); + const __m128i t1 = _mm_cmpeq_epi8(s, dq); + const __m128i t2 = _mm_cmpeq_epi8(s, bs); + const __m128i t3 = _mm_cmpeq_epi8(_mm_max_epu8(s, sp), sp); // s < 0x20 <=> max(s, 0x1F) == 0x1F + const __m128i x = _mm_or_si128(_mm_or_si128(t1, t2), t3); + unsigned short r = static_cast(_mm_movemask_epi8(x)); + if (RAPIDJSON_UNLIKELY(r != 0)) { // some of characters is escaped + SizeType len; +#ifdef _MSC_VER // Find the index of first escaped + unsigned long offset; + _BitScanForward(&offset, r); + len = offset; +#else + len = static_cast(__builtin_ffs(r) - 1); +#endif + char* q = reinterpret_cast(os_->PushUnsafe(len)); + for (size_t i = 0; i < len; i++) + q[i] = p[i]; + + p += len; + break; + } + _mm_storeu_si128(reinterpret_cast<__m128i *>(os_->PushUnsafe(16)), s); + } + + is.src_ = p; + return RAPIDJSON_LIKELY(is.Tell() < length); +} +#elif defined(RAPIDJSON_NEON) +template<> +inline bool Writer::ScanWriteUnescapedString(StringStream& is, size_t length) { + if (length < 16) + return RAPIDJSON_LIKELY(is.Tell() < length); + + if (!RAPIDJSON_LIKELY(is.Tell() < length)) + return false; + + const char* p = is.src_; + const char* end = is.head_ + length; + const char* nextAligned = reinterpret_cast((reinterpret_cast(p) + 15) & static_cast(~15)); + const char* endAligned = reinterpret_cast(reinterpret_cast(end) & static_cast(~15)); + if (nextAligned > end) + return true; + + while (p != nextAligned) + if (*p < 0x20 || *p == '\"' || *p == '\\') { + is.src_ = p; + return RAPIDJSON_LIKELY(is.Tell() < length); + } + else + os_->PutUnsafe(*p++); + + // The rest of string using SIMD + const uint8x16_t s0 = vmovq_n_u8('"'); + const uint8x16_t s1 = vmovq_n_u8('\\'); + const uint8x16_t s2 = vmovq_n_u8('\b'); + const uint8x16_t s3 = vmovq_n_u8(32); + + for (; p != endAligned; p += 16) { + const uint8x16_t s = vld1q_u8(reinterpret_cast(p)); + uint8x16_t x = vceqq_u8(s, s0); + x = vorrq_u8(x, vceqq_u8(s, s1)); + x = vorrq_u8(x, vceqq_u8(s, s2)); + x = vorrq_u8(x, vcltq_u8(s, s3)); + + x = vrev64q_u8(x); // Rev in 64 + uint64_t low = vgetq_lane_u64(vreinterpretq_u64_u8(x), 0); // extract + uint64_t high = vgetq_lane_u64(vreinterpretq_u64_u8(x), 1); // extract + + SizeType len = 0; + bool escaped = false; + if (low == 0) { + if (high != 0) { + uint32_t lz = internal::clzll(high); + len = 8 + (lz >> 3); + escaped = true; + } + } else { + uint32_t lz = internal::clzll(low); + len = lz >> 3; + escaped = true; + } + if (RAPIDJSON_UNLIKELY(escaped)) { // some of characters is escaped + char* q = reinterpret_cast(os_->PushUnsafe(len)); + for (size_t i = 0; i < len; i++) + q[i] = p[i]; + + p += len; + break; + } + vst1q_u8(reinterpret_cast(os_->PushUnsafe(16)), s); + } + + is.src_ = p; + return RAPIDJSON_LIKELY(is.Tell() < length); +} +#endif // RAPIDJSON_NEON + +RAPIDJSON_NAMESPACE_END + +#if defined(_MSC_VER) || defined(__clang__) +RAPIDJSON_DIAG_POP +#endif + +#endif // RAPIDJSON_RAPIDJSON_H_ diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index ea2834ae62..74391ded28 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,7 +1,2 @@ -#!/bin/bash -set -euo pipefail -IFS=$'\n\t' - - -find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|.hpp|.inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | grep -v 'build/' | grep -v 'include/rapidjson'| xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|.hpp|.inc|include/rapidjson/")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' diff --git a/test/ck_tile/moe_sorting/test_moe_sorting_cases.inc b/test/ck_tile/moe_sorting/test_moe_sorting_cases.inc old mode 100755 new mode 100644 diff --git a/test/ck_tile/permute/test_permute_cases.inc b/test/ck_tile/permute/test_permute_cases.inc old mode 100755 new mode 100644 diff --git a/test/ck_tile/smoothquant/test_smoothquant_cases.inc b/test/ck_tile/smoothquant/test_smoothquant_cases.inc old mode 100755 new mode 100644 From 47d020a99322e99ffce7b51fb83cf4d8e2b4d30f Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Wed, 3 Sep 2025 09:34:11 +0200 Subject: [PATCH 013/404] refactor: use snake_case naming in ck_tile/core components (#2766) --- .../algorithm/static_encoding_pattern.hpp | 81 ++++++++-------- .../batched_transpose_common_policy.hpp | 12 +-- .../pipeline/batched_transpose_policy.hpp | 12 +-- .../ops/epilogue/cshuffle_epilogue.hpp | 15 +-- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 94 +++++++++---------- .../gemm_aquant_pipeline_ag_bg_cr_policy.hpp | 53 +++++------ .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 4 +- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 14 +-- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 4 +- .../pipeline/gemm_group_quant_utils.hpp | 13 +-- .../test_print_static_encoding_pattern.cpp | 24 +++-- 11 files changed, 171 insertions(+), 155 deletions(-) diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 1f6c389090..c96daf3e99 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -104,7 +104,7 @@ enum struct tile_distribution_pattern block_raked, }; -struct TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern { }; @@ -126,7 +126,7 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern { }; @@ -136,12 +136,13 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! @@ -165,7 +166,7 @@ struct TileDistributionEncodingPattern2D -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -244,7 +246,7 @@ struct TileDistributionEncodingPattern2D, @@ -255,7 +257,7 @@ struct TileDistributionEncodingPattern2D>{}); // -> } - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -273,12 +275,13 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! @@ -295,7 +298,7 @@ struct TileDistributionEncodingPattern2D, @@ -306,7 +309,7 @@ struct TileDistributionEncodingPattern2D>{}); // -> } - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -336,21 +339,21 @@ template -CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D&) +CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d&) { - using PatternType = TileDistributionEncodingPattern2D; + using PatternType = tile_distribution_encoding_pattern_2d; - printf("TileDistributionEncodingPattern2D: ", BlockSize, YPerTile, diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp index 3b8d5a142e..9e2a67f940 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -21,12 +21,12 @@ struct BatchedTransposeCommonPolicy constexpr index_t kVectorSize = Problem::VectorSizeInput; static_assert((kLeadDimPerBlock * kVectorSize) % kBlockSize == 0, ""); - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } }; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index e6bbc709ea..137584c3e8 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -18,12 +18,12 @@ struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy constexpr index_t NPerBlock = Problem::kNPerBlock; constexpr index_t VecLoadSize = Problem::VectorSizeOutput; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 1d0a4c42f4..7510df091c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -291,13 +291,14 @@ struct CShuffleEpilogue "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); using TileEncodingPattern = - TileDistributionEncodingPattern2D; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + tile_distribution_encoding_pattern_2d; + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); auto d_dram_windows = generate_tuple( [&](auto idx) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 40ee952b1b..8d47ab878e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -168,11 +168,11 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = TileDistributionEncodingPattern2D; + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -494,24 +494,24 @@ struct UniversalGemmBasePolicy // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: KPerBlock X MPerBlock else { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -530,24 +530,24 @@ struct UniversalGemmBasePolicy // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: NPerBlock X KPerBlock else { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -562,13 +562,13 @@ struct UniversalGemmBasePolicy constexpr index_t VecLoadSize = GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } template @@ -582,13 +582,13 @@ struct UniversalGemmBasePolicy constexpr index_t VecLoadSize = GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } template diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 52c99f8e99..926f63b5a9 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -55,44 +55,43 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { - using TileEncodingPattern = - TileDistributionEncodingPatternAQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_aq< + BlockGemmShape, + WarpGemm, + BlockSize, + MPerBlock / WarpGemm::kM, + ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()), + KPerBlockAQ, + VecLoadSize, + PreshuffleQuant>; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } else { if constexpr(Problem::TransposeC) { using TileEncodingPatternTransposeC = - TileDistributionEncodingPatternAQTransposedC; - return TileEncodingPatternTransposeC::Make2DStaticTileDistribution(); + tile_distribution_encoding_pattern_aq_transposed_c; + return TileEncodingPatternTransposeC::make_2d_static_tile_distribution(); } else { - using TileEncodingPattern = TileDistributionEncodingPatternAQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_aq; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } } } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 037cef0553..5ce4268dca 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -330,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -342,7 +342,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index ff986d86fb..eea8038edf 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -52,14 +52,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC Problem::TransposeC>; static_assert(std::is_same_v); - using TileEncodingPattern = TileDistributionEncodingPatternBQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_bq; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } template diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 7ce6598b80..8f191f0f94 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -326,7 +326,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -338,7 +338,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp index 56a906a6bc..54b64c34be 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp @@ -53,7 +53,7 @@ template -struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); static constexpr index_t warp_size = get_warp_size(); @@ -70,7 +70,7 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter // KWarps > 1 isn't supported static_assert(KWarps == 1); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { if constexpr(PreshuffleQuant) { @@ -119,7 +119,8 @@ template -struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_aq_transposed_c + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -152,7 +153,7 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -171,7 +172,7 @@ template -struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -204,7 +205,7 @@ struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPatter static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, diff --git a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp index 3ff23e2e11..3b1b6ffb6d 100644 --- a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp +++ b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp @@ -32,13 +32,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern) { // Test printing thread raked pattern using PatternType = - TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>; + tile_distribution_encoding_pattern_2d<64, + 8, + 16, + 4, + tile_distribution_pattern::thread_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos); @@ -52,13 +56,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern) { // Test printing warp raked pattern using PatternType = - TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>; + tile_distribution_encoding_pattern_2d<128, + 16, + 32, + 8, + tile_distribution_pattern::warp_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos); @@ -72,13 +80,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern) { // Test printing block raked pattern using PatternType = - TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>; + tile_distribution_encoding_pattern_2d<256, + 32, + 64, + 16, + tile_distribution_pattern::block_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos); From 0282d98412fb6abac5c7e650355491f504b5ca4d Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Wed, 3 Sep 2025 13:38:17 -0700 Subject: [PATCH 014/404] [CK TILE] Stream-K tile partitioner (#2708) * initial commit for skeleton code * replaced skeleton code with old streamk b2c map functions from old CK, still need to clean up the code * fixed up code to match CK Tile convention: data type changes, naming changes, etc. * change for num_sk_blocks data type * formatting fix * minor fixes * moved reduction argument to template * resolved comments from PR review: standardizing naming, pruning unneeded code * resolve errors from merge of device op PR: moved enum to common file * switching to uint32_t due to implementation constraints: divmod only takes uint32_t and mixing signed and unsigned types causes problems * unsigned type fix * add const qualifier * added documentation for template parameters * documentation edit --- include/ck_tile/ops/common.hpp | 1 + include/ck_tile/ops/common/streamk_common.hpp | 14 + .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 451 ++++++++++++++++++ .../ops/gemm/kernel/streamk_gemm_kernel.hpp | 62 ++- 4 files changed, 494 insertions(+), 34 deletions(-) create mode 100644 include/ck_tile/ops/common/streamk_common.hpp diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 027e2fdd94..7c6adc3ec2 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -6,3 +6,4 @@ #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/common/streamk_common.hpp b/include/ck_tile/ops/common/streamk_common.hpp new file mode 100644 index 0000000000..5dbe6223c4 --- /dev/null +++ b/include/ck_tile/ops/common/streamk_common.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +enum StreamKReductionStrategy : uint32_t +{ + Atomic = 0u, + Reduction = 1u +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index b621468e92..92ae6411a5 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -9,6 +9,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" namespace ck_tile { @@ -364,4 +365,454 @@ struct GemmSpatiallyLocalTilePartitioner index_t N; }; +/** + * @brief Stream-K tile partitioner that dynamically balances work across workgroups + * + * This partitioner is responsible for mapping workgroups to tiles in the C tensor + * for the Stream-K algorithm which decomposes the GEMM problem + * into smaller work units and distributes them more evenly across available blocks, + * improving load balancing especially for cases where the K dimension is large. + * + * @tparam BlockGemmShapeType A class providing basic GEMM parameters. + * @tparam ReductionStrategy A class that defines the reduction strategy for the results in + * the C Tensor. + * @tparam TileSwizzleSubM A value that defines the size of the swizzle group along the m + * dimension, where the swizzle group denotes consecutive tiles down a column. For instance a + * swizzle group of 8 denotes tiles 0, 1, ..., 7, map to tiles [0,0], [1,0], ..., [7,0] in the C + * tensor. + */ +template +struct StreamKTilePartitioner +{ + using BlockGemmShape = BlockGemmShapeType; + + static constexpr uint32_t MPerBlock = BlockGemmShape::kM; + static constexpr uint32_t NPerBlock = BlockGemmShape::kN; + static constexpr uint32_t KPerBlock = BlockGemmShape::kK; + + CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete; + + /** + * @brief Construct Stream-K tile partitioner with problem dimensions + */ + CK_TILE_HOST_DEVICE StreamKTilePartitioner(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t num_cu, + uint32_t occupancy, + uint32_t sk_blocks = 0xffffffff) noexcept + : M_(M), N_(N), K_(K) + { + num_tile_m_ = integer_divide_ceil(M, MPerBlock); + num_tile_n_ = integer_divide_ceil(N, NPerBlock); + num_tile_k_ = integer_divide_ceil(K, KPerBlock); + + constexpr uint32_t min_k_iters_per_sk_block = 2; + uint32_t num_tiles = num_tile_m_ * num_tile_n_; + k_iters_per_tile = mdiv(num_tile_k_); + + // one cu can hold one wg at one time, from the whole cZ's point of view + // if number of wg is same as num_cu, we call it 1 dispatch + // if number of wg is 2x num_cu, we call it 2 dispatches. + // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + const uint32_t full_dispatches = num_tiles / num_cu; + const uint32_t full_dispatch_tiles = full_dispatches * num_cu; + const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatch_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatch_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... + sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatch_tiles; + } + else + { + // otherwise, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatch_tiles + num_cu; + } + + // uint32_t dp_iters_per_block = k_iters_per_tile.get(); + uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles; + uint32_t dp_num_blocks = 0; + + { + const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + const uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? num_cu * sk_occupancy + : min(num_cu, sk_total_iters / min_k_iters_per_sk_block); + + // if use dp for sk-block, how many iters do we need + const uint32_t dp_for_sk_iters = k_iters_per_tile.get(); + + uint32_t best_sk_score = + std::numeric_limits::max(); // we need to find the smallest sk iters + for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; + tentative_sk_blocks++) + { + const uint32_t tentative_sk_iters_per_block = + (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks; + const uint32_t tentative_sk_iters = tentative_sk_iters_per_block; + const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles; + + // the more sk_blocks_per_tile, the worse the overhead + uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile; + if(tentative_sk_blocks % sk_tiles != 0) + { + // penalty for uneven divide + cross_sk_blocks_overhead += + sk_blocks_per_tile * tentative_sk_iters_per_block / 50; + } + + const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead; + + if(tentative_sk_score < best_sk_score) + { + best_sk_score = tentative_sk_score; + sk_num_blocks = tentative_sk_blocks; + } + } + + if(best_sk_score >= dp_for_sk_iters) + { + sk_num_blocks = 0; + } + + // give a chance to control num of sk blocks + sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks; + + if(sk_num_blocks == 0) + { + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + else + { + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. + // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu; + } + } + n_tiles = mdiv2(num_tile_n_); + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) + { + const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get()); + const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get()); + } + } + + /** + * @brief Calculate optimal grid size for Stream-K + */ + CK_TILE_HOST auto GridSize() const noexcept -> dim3 + { + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) + { + return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1); + } + else + return dim3(reduction_start_block_idx, 1, 1); + } + + /** + * @brief Calculate number of loop iterations over K dimension for given work unit + */ + CK_TILE_HOST_DEVICE static auto GetLoopNum(uint32_t K) noexcept -> uint32_t + { + return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time + } + + /** + * @brief Get output tile index for standard 2D mapping (compatibility) + */ + CK_TILE_DEVICE auto + GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple + { + uint32_t m_tile_idx, n_tile_idx; + n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx); + + // swizzle tile + + uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM; + + const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem)) + ? TileSwizzleSubM + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM; + m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM, + n_tile_idx_with_adapt); + } + + /** + * @brief Get work range for a given block ID + */ + CK_TILE_DEVICE void + GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = GetSkTotalIters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + /** + * @brief Get total number of iterations for sk tiles + */ + CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + /** + * @brief Get total number of sk tiles + */ + CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept + { + // tiles for sk + uint32_t sk_total_iters = GetSkTotalIters(); + return k_iters_per_tile.div(sk_total_iters); + } + + /** + * @brief Get length of loop iterations for stream-k loop + */ + CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const noexcept + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t total_iter_length_val = static_cast(total_iter_length); + uint32_t current_iter_length = + min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, + total_iter_length_val); + return current_iter_length; + } + + /** + * @brief Get index of tile during a specified iteration + */ + CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept + { + return k_iters_per_tile.div(iter); + } + + /** + * @brief Get index of tile during a specified iteration + */ + CK_TILE_DEVICE void + GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept + { + uint32_t tile_idx_val = static_cast(tile_idx); + uint32_t iter_offset_val = static_cast(iter_offset); + k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val); + } + + /** + * @brief Calculates the buffer space needed for accumulation + */ + CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + /** + * @brief Calculates the buffer space needed for the semaphore + */ + CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept + { + return GetSkTiles() * sizeof(uint32_t); + } + + /** + * @brief Calculates the total buffer space needed for accumulation and the semaphore + */ + CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept + { + return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore(); + } + + /** + * @brief Get location of intersection of tiles for reduction + */ + CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_, + const mdiv& equiv_tiles_) const noexcept + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + /** + * @brief Calculate the number of tiles needed for the number of sk blocks + */ + CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const noexcept + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + /** + * @brief Calculate the amount of total accumulation buffers required for stream-k + */ + CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept + { + uint32_t tiles_cover_big_blocks = + GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + GetTilesCoverSkBlock(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + /** + * @brief Calculate offset based on tile index for big/little tiles + */ + CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept + { + uint32_t tiles_cover_big_blocks = + GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little); + return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec); + } + } + + /** + * @brief Calculate offset based on block_idx index for big/little streamk blocks + */ + CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little); + return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec); + } + } + + // Getters for problem dimensions + CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; } + CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; } + CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; } + + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + mdiv2 n_tiles; + mdiv k_iters_per_tile; + mdiv equiv_tiles_big; // for reduction + mdiv equiv_tiles_little; // for reduction + + private: + uint32_t M_, N_, K_; + uint32_t num_tile_m_, num_tile_n_, num_tile_k_; +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index a05e7b2ad0..77c431e49c 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -9,15 +9,6 @@ namespace ck_tile { -enum StreamKReductionStrategy : uint32_t -{ - /// @brief Workgroups atomically add their results to the C tensor - Atomic = 0u, - /// @brief For a given tile in the C tensor, one workgroup accumulates results of other - /// contributing workgroups - Reduction = 1u -}; - /// @brief The Stream K GEMM kernel host arguments. /// /// @par Overview @@ -37,7 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, - index_t num_sk_blocks_ = -1) + uint32_t num_sk_blocks_ = 0xffffffff) : UniversalGemmHostArgs<>({a_ptr_}, {b_ptr_}, {/*ds_ptr*/}, @@ -56,7 +47,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> } ck_tile::StreamKReductionStrategy reduction_strategy; - index_t num_sk_blocks; + uint32_t num_sk_blocks; }; template @@ -103,7 +94,7 @@ struct StreamKKernel /// @brief The strategy used by work groups to compute final results in C tensor. StreamKReductionStrategy reduction_strategy; /// @brief The number of stream k blocks. - index_t num_sk_blocks; + uint32_t num_sk_blocks; /// @brief A pointer to a buffer in device memory for accumulating partial via reduction /// strategy. void* workspace_ptr; @@ -152,29 +143,32 @@ struct StreamKKernel CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) { - index_t occupancy = static_cast(Occupancy()); - index_t num_cu = static_cast(NumCU()); + uint32_t occupancy = static_cast(Occupancy()); + uint32_t num_cu = static_cast(NumCU()); - return StreamKKernelArgs{ - {host_args.as_ptr, - host_args.bs_ptr, - host_args.ds_ptr, - host_args.e_ptr, - host_args.M, - host_args.N, - host_args.K, - host_args.stride_As, - host_args.stride_Bs, - host_args.stride_Ds, - host_args.stride_E, - host_args.k_batch}, - host_args.reduction_strategy, - host_args.num_sk_blocks, - // The workspace pointer is set to nullptr because we must first - // instantiate the TilePartitioner to get the necessary size - /*workspace_ptr =*/nullptr, - TilePartitioner{ - host_args.M, host_args.N, host_args.K, num_cu, occupancy, host_args.num_sk_blocks}}; + return StreamKKernelArgs{{host_args.as_ptr, + host_args.bs_ptr, + host_args.ds_ptr, + host_args.e_ptr, + host_args.M, + host_args.N, + host_args.K, + host_args.stride_As, + host_args.stride_Bs, + host_args.stride_Ds, + host_args.stride_E, + host_args.k_batch}, + host_args.reduction_strategy, + host_args.num_sk_blocks, + // The workspace pointer is set to nullptr because we must first + // instantiate the TilePartitioner to get the necessary size + /*workspace_ptr =*/nullptr, + TilePartitioner{static_cast(host_args.M), + static_cast(host_args.N), + static_cast(host_args.K), + num_cu, + occupancy, + host_args.num_sk_blocks}}; } CK_TILE_HOST static bool From 80ce6a573b4bb37c17c20eaac4fab48666be4edb Mon Sep 17 00:00:00 2001 From: kylasa Date: Wed, 3 Sep 2025 15:32:54 -0700 Subject: [PATCH 015/404] gtest to test atomic_add for a tensor (#2716) * Code drop for gtest to test atomic_add for a tensor * Adding additional test cases * Fix clang errors in CI pipeline * Updated test cases * Fix the Navi card atomic add problem * solved the define problem * add more print out traces * Fix the float4 missing case * solved the gfx9 errors * Address the comment --------- Co-authored-by: Khushbu Co-authored-by: Thomas Ning --- CMakeLists.txt | 4 + test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/atomic_add_op/CMakeLists.txt | 2 + test/ck_tile/atomic_add_op/test_atomic.cpp | 407 +++++++++++++++++++++ test/ck_tile/atomic_add_op/test_atomic.hpp | 115 ++++++ 5 files changed, 529 insertions(+) create mode 100644 test/ck_tile/atomic_add_op/CMakeLists.txt create mode 100755 test/ck_tile/atomic_add_op/test_atomic.cpp create mode 100644 test/ck_tile/atomic_add_op/test_atomic.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 52bb2ccd2d..ddadfb0353 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,6 +234,10 @@ endif() # new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA set(CK_TILE_USE_WMMA 0) +if (SUPPORTED_GPU_TARGETS MATCHES "gfx10") + add_definitions(-DCK_GFX1030_SUPPORT) +endif() + if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 374e5b4990..695cad19bc 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(add_rmsnorm2d_rdquant) add_subdirectory(gemm_block_scale) add_subdirectory(utility) add_subdirectory(reduce) +add_subdirectory(atomic_add_op) diff --git a/test/ck_tile/atomic_add_op/CMakeLists.txt b/test/ck_tile/atomic_add_op/CMakeLists.txt new file mode 100644 index 0000000000..5dfb4d9db3 --- /dev/null +++ b/test/ck_tile/atomic_add_op/CMakeLists.txt @@ -0,0 +1,2 @@ +add_gtest_executable(test_atomic test_atomic.cpp) +set(CTEST_OUTPUT_ON_FAILURE ON) diff --git a/test/ck_tile/atomic_add_op/test_atomic.cpp b/test/ck_tile/atomic_add_op/test_atomic.cpp new file mode 100755 index 0000000000..d4f8c5a6a5 --- /dev/null +++ b/test/ck_tile/atomic_add_op/test_atomic.cpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights +// reserved. + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "test_atomic.hpp" + +struct AtomicKernelParam +{ + AtomicKernelParam(ck_tile::index_t m_, ck_tile::index_t n_) : m(m_), n(n_) {} + ck_tile::index_t m; + ck_tile::index_t n; +}; + +template +class TestAtomicKernel : public ::testing::TestWithParam> +{ + struct AtomicKernelWaveSize64 + { + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<128, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + static constexpr ck_tile::index_t kBlockSize = 128; // 2 waves * 64 lanes + }; + + struct AtomicKernelWaveSize32 + { + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<32, 8>; // 32*2 == 64 + static constexpr ck_tile::index_t kBlockSize = 64; // 2 waves * 32 lanes + }; + + template + void RunTestImpl_(const AtomicKernelParam& params, int require_warp_size, const char* tag) + { + // Device capability check & skip if wavesize mismatches + int dev = 0; + hipDeviceProp_t prop{}; + if(hipGetDevice(&dev) != hipSuccess || hipGetDeviceProperties(&prop, dev) != hipSuccess) + { + GTEST_SKIP() << "[" << tag << "] hipGetDeviceProperties failed; skipping."; + } + if(prop.warpSize != require_warp_size) + { + GTEST_SKIP() << "[" << tag << "] Device warpSize=" << prop.warpSize << " (requires " + << require_warp_size << "); skipping."; + } + + using XDataType = DataType_; + + const ck_tile::index_t m = params.m; + const ck_tile::index_t n = params.n; + + std::cout << "[" << tag << "] Input Tensor Dimensions: " << m << ", " << n << std::endl; + + constexpr int dword_bytes = 4; + const int base_vec = dword_bytes / static_cast(sizeof(XDataType)); + const int vec = multiple_ * base_vec; + + ASSERT_EQ(n % vec, 0) << " Row dimension must be divisible by vector width: n=" << n + << " vec=" << vec << " (multiple=" << multiple_ + << ", base_vec=" << base_vec << ")"; + + // host tensors + ck_tile::HostTensor x_host_ref({m, n}); + ck_tile::HostTensor x_host_dev({m, n}); + + // device buffers + ck_tile::DeviceMem x_dev_input(x_host_dev.get_element_space_size_in_bytes()); + x_dev_input.SetZero(); + x_host_ref.SetZero(); + + using BlockWaves = typename Config::BlockWaves; + using BlockTile = typename Config::BlockTile; + using WaveTile = typename Config::WaveTile; + using Vector = ck_tile::sequence<1, vec>; + + // Compile-time sanity: BlockTile == WaveTile * BlockWaves + static_assert(BlockTile::at(ck_tile::number<0>{}) == + WaveTile::at(ck_tile::number<0>{}) * BlockWaves::at(ck_tile::number<0>{}), + "BlockTile.M must equal WaveTile.M * BlockWaves.M"); + static_assert(BlockTile::at(ck_tile::number<1>{}) == + WaveTile::at(ck_tile::number<1>{}) * BlockWaves::at(ck_tile::number<1>{}), + "BlockTile.N must equal WaveTile.N * BlockWaves.N"); + + std::cout << "[" << tag << "] Vector per thread = " << vec + << " BlockWaves=" << BlockWaves::at(ck_tile::number<0>{}) << "x" + << BlockWaves::at(ck_tile::number<1>{}) + << " WaveTile=" << WaveTile::at(ck_tile::number<0>{}) << "x" + << WaveTile::at(ck_tile::number<1>{}) + << " BlockTile=" << BlockTile::at(ck_tile::number<0>{}) << "x" + << BlockTile::at(ck_tile::number<1>{}) << std::endl; + + const ck_tile::index_t kGridSize = + ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{})); + + using Shape = ck_tile::AtomicKernelShape; + using Problem = ck_tile::AtomicKernelProblem; + using Kernel = ck_tile::AtomicKernel; + + constexpr ck_tile::index_t kBlockSize = Config::kBlockSize; + constexpr ck_tile::index_t kBlockPerCu = 1; + + (void)hipGetLastError(); // clear sticky + + launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_dev_input.GetDeviceBuffer()), + m, + n)); + + ASSERT_EQ(hipPeekAtLastError(), hipSuccess) + << "[" << tag << "] hipPeekAtLastError: " << hipGetErrorString(hipGetLastError()); + ASSERT_EQ(hipDeviceSynchronize(), hipSuccess) + << "[" << tag << "] hipDeviceSynchronize failed"; + + // host reference computation + x_dev_input.FromDevice(x_host_dev.mData.data()); + for(int i = 0; i < m; ++i) + for(int j = 0; j < n; ++j) + x_host_ref(i, j) = static_cast(1); + + const bool pass = ck_tile::check_err(x_host_dev, x_host_ref); + EXPECT_TRUE(pass); + } + + protected: + // WaveSize = 64 path + void RunTest(const AtomicKernelParam& params) + { + RunTestImpl_(params, /*require_warp_size=*/64, "WS64"); + } + + // WaveSize = 32 path + void RunTestWave32(const AtomicKernelParam& params) + { + RunTestImpl_(params, /*require_warp_size=*/32, "WS32"); + } +}; + +class TestAtomicKernelHalf_1 : public TestAtomicKernel +{ +}; +class TestAtomicKernelHalf_2 : public TestAtomicKernel +{ +}; +class TestAtomicKernelHalf_4 : public TestAtomicKernel +{ +}; +class TestAtomicKernelBF16_1 : public TestAtomicKernel +{ +}; +class TestAtomicKernelBF16_2 : public TestAtomicKernel +{ +}; +class TestAtomicKernelBF16_4 : public TestAtomicKernel +{ +}; +class TestAtomicKernelBF8_1 : public TestAtomicKernel +{ +}; +class TestAtomicKernelBF8_2 : public TestAtomicKernel +{ +}; +class TestAtomicKernelFP8_1 : public TestAtomicKernel +{ +}; +class TestAtomicKernelFP8_2 : public TestAtomicKernel +{ +}; +class TestAtomicKernelFloat_1 : public TestAtomicKernel +{ +}; +class TestAtomicKernelFloat_2 : public TestAtomicKernel +{ +}; +class TestAtomicKernelFloat_4 : public TestAtomicKernel +{ +}; + +// +// WaveSize=64 tests (auto-skip on wave32 devices) +// +#if defined(CK_USE_XDL) +TEST_P(TestAtomicKernelHalf_1, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelHalf_2, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelHalf_4, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelBF16_1, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelBF16_2, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelBF16_4, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelBF8_1, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelBF8_2, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelFP8_1, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelFP8_2, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelFloat_1, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelFloat_2, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} +TEST_P(TestAtomicKernelFloat_4, TestCorrectness) +{ + auto [M, N] = GetParam(); + this->RunTest({M, N}); +} + +// +// WaveSize=32 tests (auto-skip on wave64 devices) +// +#else +TEST_P(TestAtomicKernelHalf_1, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelHalf_2, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelHalf_4, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelBF16_1, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelBF16_2, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelBF16_4, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelBF8_1, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelBF8_2, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelFP8_1, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelFP8_2, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelFloat_1, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +TEST_P(TestAtomicKernelFloat_2, TestCorrectnessWS32) +{ + auto [M, N] = GetParam(); + this->RunTestWave32({M, N}); +} +#endif + +// Common parameter lists +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelHalf_1, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelHalf_2, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelHalf_4, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelBF16_1, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelBF16_2, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelBF16_4, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelBF8_1, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelBF8_2, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelFP8_1, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelFP8_2, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelFloat_1, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelFloat_2, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); + +#if defined(CK_USE_XDL) +INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite, + TestAtomicKernelFloat_4, + ::testing::Values(std::tuple{64, 8}, + std::tuple{64, 16}, + std::tuple{64, 32})); +#endif diff --git a/test/ck_tile/atomic_add_op/test_atomic.hpp b/test/ck_tile/atomic_add_op/test_atomic.hpp new file mode 100644 index 0000000000..a6697f824b --- /dev/null +++ b/test/ck_tile/atomic_add_op/test_atomic.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template +struct AtomicKernelShape +{ + static constexpr index_t MWarps = BlockWaves::at(number<0>{}); + static constexpr index_t NWarps = BlockWaves::at(number<1>{}); + + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WaveTile::at(number<0>{}); + static constexpr index_t Warp_N = WaveTile::at(number<1>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t WarpPerBlock_M = MWarps; + static constexpr index_t WarpPerBlock_N = NWarps; + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{}); + + static constexpr index_t BlockSize = get_warp_size() * WaveNum; +}; + +template +struct AtomicKernelProblem +{ + using XDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +template +struct AtomicKernel +{ + using Problem = remove_cvref_t; + using XDataType = typename Problem::XDataType; + + static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + + template + CK_TILE_DEVICE static constexpr auto MakeTileDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + + constexpr index_t X0 = S::ThreadPerWarp_N; + constexpr index_t X1 = S::Vector_N; + + constexpr index_t Y0 = S::WaveNum; + constexpr index_t Y2 = warp_size / X0; + constexpr index_t Y1 = S::Warp_M / Y2; + + constexpr auto encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}; + + return make_static_tile_distribution(encoding); + } + + CK_TILE_DEVICE void operator()(XDataType* input, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + constexpr auto block_dims = make_tuple(number{}, number{}); + + const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M); + + const auto input_view = + make_naive_tensor_view( + input, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + auto input_window = make_tile_window(input_view, block_dims, {iM, 0}); + + const index_t num_iterations = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + using tmp_tile = + decltype(make_static_distributed_tensor(MakeTileDistribution())); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_iterations; iN++) + { + tmp_tile add_value_tile; + tile_elementwise_inout([](auto& c) { c = static_cast(1.0f); }, + add_value_tile); + + update_tile(input_window, add_value_tile); + __syncthreads(); + + move_tile_window(input_window, {0, S::Block_N}); + } + } +}; + +} // namespace ck_tile From e2d28a92af81139b9743e454570c1ada3146f87c Mon Sep 17 00:00:00 2001 From: linqunAMD Date: Thu, 4 Sep 2025 08:33:40 +0800 Subject: [PATCH 016/404] Extend XDL kernel to Support RDNA3/4 - Part 2 (#2722) Update Blockwise and Gridwise files to support both wave32 & wave64. 1. Calculate WaveSize from template parameter, instead of hard code it to 64, some "64" is also replace with WaveSize 2. Move BN0Shuffled and BK0Shuffled to device side. we can't get correct mfma inst info in host side. 3. Update b_thread_offset_n and b_thread_offset_k in gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp for gfx11. in gfx11, input data is duplicated for each 16 threads, it is different with all of others. 4. Modify a1_threadwise_copy in gridwise_batched_*gemm*gemm for gfx11. for gfx11, we need duplicate input and swizzle A if transposeC isn't enabled. --- .../gpu/block/blockwise_gemm_dpp.hpp | 11 +-- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 9 +- .../block/blockwise_gemm_pipeline_xdlops.hpp | 16 +-- ...ipeline_xdlops_b_preshuffle_dequant_v1.hpp | 3 +- ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 3 +- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 3 +- ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 3 +- ...peline_xdlops_b_preshuffle_gufusion_v3.hpp | 3 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v1.hpp | 5 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v2.hpp | 5 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 5 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 2 + ...line_xdlops_blockscale_b_preshuffle_v1.hpp | 3 +- ...line_xdlops_blockscale_b_preshuffle_v3.hpp | 3 +- ...oe_blockscale_b_preshuffle_gufusion_v1.hpp | 3 +- ...oe_blockscale_b_preshuffle_gufusion_v3.hpp | 3 +- ..._xdlops_moe_blockscale_b_preshuffle_v1.hpp | 3 +- ..._xdlops_moe_blockscale_b_preshuffle_v3.hpp | 3 +- .../block/blockwise_gemm_smfmac_xdlops.hpp | 13 ++- .../gpu/block/blockwise_gemm_xdlops.hpp | 20 ++-- .../blockwise_gemm_xdlops_skip_b_lds.hpp | 9 +- .../element/unary_element_wise_operation.hpp | 28 +++++- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 42 ++++++-- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 42 +++++--- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 42 +++++--- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 47 ++++++--- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 10 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 10 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 13 ++- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 23 +++-- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 45 ++++++--- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 10 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 10 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 19 ++-- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 28 +++--- ...fle_v3_multi_d_blockscale_b_preshuffle.hpp | 42 ++++---- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 17 ++-- ...se_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp | 23 ++--- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 4 +- .../gpu/grid/gridwise_moe_gemm.hpp | 98 ++++++++++--------- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 45 +++++---- .../gpu/grid/gridwise_moe_mx_gemm.hpp | 18 ++-- .../gpu/grid/gridwise_moe_mx_gemm_bns.hpp | 18 ++-- .../grid/gridwise_moe_mx_gemm_bpreshuffle.hpp | 31 +++--- .../threadwise_tensor_slice_transfer.hpp | 2 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 72 +++++++++++--- include/ck/utility/type_convert.hpp | 36 ++++++- 48 files changed, 605 insertions(+), 300 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp index f03427a7ea..5012e53d33 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -38,13 +38,15 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t WaveSize = get_warp_size(); - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t KPerBlock = BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); @@ -54,9 +56,6 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); - StaticBufferTupleOfVector; // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. - static constexpr index_t WaveSize = 64; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); @@ -77,9 +79,6 @@ struct BlockwiseGemmXdlops_mx_pipeline_base static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KPerInnerLoop = KPack; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - // Hardcode to 2, for better 8-bit access pattern static constexpr index_t MXdlPack = 2; @@ -206,6 +205,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base Tuple5 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { +#if defined(__HIP_DEVICE_COMPILE__) static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, @@ -213,6 +213,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); +#endif } // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 231dbf817c..613886453b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -32,9 +32,9 @@ template struct BlockwiseGemmXdlops_pipeline_hotloop_inst { - static constexpr index_t WaveSize = 64; static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / (WaveNumM * WaveNumN); static constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); @@ -108,7 +108,11 @@ struct BlockwiseGemmXdlops_pipeline_v4 using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static_assert(MWaves > 0); + static_assert(NWaves > 0); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); @@ -121,9 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v4 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 8b227a8aa1..b0a583030e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}; @@ -159,7 +160,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 29750b8baa..fc7360dee5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< using Base::AMmaKStride; using Base::BMmaKStride; using Base::c_thread_desc_; + using Base::WaveSize; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; @@ -154,7 +155,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp index fe89e700c4..68cde2b880 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -144,6 +144,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp index c76be74e52..fa67750e87 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -145,6 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index d8f11572a8..437c73fa97 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -143,6 +143,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 601756be44..03f50fe17a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -142,6 +142,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6f0404a1ca..3d725fc5fa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -151,6 +151,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index acd1d2ae49..ff64b6fe2a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -93,8 +93,10 @@ struct BlockwiseGemmXdlops_pipeline_base NPerXDL, xdlops_gemm.KPerXdlops>; +#if defined(__HIP_DEVICE_COMPILE__) static_assert(KPerThread % KPack == 0, "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); +#endif StaticBufferTupleOfVector{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp index cc4c5a2c36..ba5b6e8292 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp index 1608506b40..3bec4821fb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp @@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< using Base::MWaves; using Base::NWaves; + using Base::WaveSize; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; @@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp index 30d6d4f812..7e33016cec 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MWaves; + using Base::WaveSize; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; @@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index 598b69cd61..5a21e41c57 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -153,6 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< using Base::MWaves; using Base::NWaves; + using Base::WaveSize; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; @@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp index 6db02d1dd7..99bca30dd9 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -152,6 +152,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MWaves; + using Base::WaveSize; static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; @@ -165,7 +166,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); constexpr index_t K2 = KPack / KGroup; - constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K1 = WaveSize / NPerXDL; constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp index 90f356987d..c553a57672 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -49,7 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static_assert(MWaves > 0); + static_assert(NWaves > 0); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); @@ -66,9 +70,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - StaticBufferTupleOfVector; - static constexpr index_t WaveSize = get_warp_size(); - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t KPerBlock = @@ -61,14 +59,15 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - StaticBufferTupleOfVector; - static constexpr index_t WaveSize = get_warp_size(); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); @@ -691,9 +692,6 @@ struct BlockwiseGemmXdlops_v2 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static_assert(KPerThread % KPack == 0, "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); @@ -790,6 +788,7 @@ struct BlockwiseGemmXdlops_v2 Tuple4 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { +#if defined(__HIP_DEVICE_COMPILE__) static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); @@ -798,6 +797,7 @@ struct BlockwiseGemmXdlops_v2 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); +#endif } // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 84ee096cba..33457f4b0a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,8 +30,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr index_t WaveSize = 64; - static constexpr index_t KPerBlock = K0PerBlock * KPack; static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); @@ -42,8 +40,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; StaticBufferTupleOfVector> 16); + fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, 0, 0); + float f32_2 = amd_assemble_cvt_f32_i4(q >> 8); + float f32_3 = amd_assemble_cvt_f32_i4(q >> 24); + fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, 0, 0); + q = q >> 4; + f32_0 = amd_assemble_cvt_f32_i4(q); + f32_1 = amd_assemble_cvt_f32_i4(q >> 16); + fp8x4_0 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_0, f32_1, fp8x4_0, 1); + f32_2 = amd_assemble_cvt_f32_i4(q >> 8); + f32_3 = amd_assemble_cvt_f32_i4(q >> 24); + fp8x4_1 = __builtin_amdgcn_cvt_pk_fp8_f32(f32_2, f32_3, fp8x4_1, 1); + return bit_cast(((static_cast(fp8x4_1) << 32) | fp8x4_0)); +#elif defined(__gfx11__) + ignore = q; + return f8x8_t{}; +#else + return amd_assembly_i4_to_fp8x8(q); +#endif +} __device__ inline bhalf4_t i4_to_bhalf4(int q) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 258d0ad0ca..70c641531b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -549,16 +549,30 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(Number{}, Number{}, Number{}); - constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; - constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; - constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + +#if defined(__gfx11__) + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed( + make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number{})); + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + decltype(acc_element_op), + Sequence, + Sequence<1, 0, 2>, + 2, + n4, + 0x76543210, + 0xfedcba98, + true>{make_tuple(0, 0, 0)}; +#else constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( A1ThreadSlice_K0_M_K1, make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); - - // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A1 matrix blockwise copy auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< FloatGemmAcc, @@ -570,6 +584,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle Sequence<1, 0, 2>, 2, n4>{acc_element_op}; +#endif + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // B1 matrix blockwise copy auto b1_blockwise_copy = @@ -619,10 +636,15 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size - +#if defined(__gfx11__) + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size * 2; + static_assert( + XdlopsGemm{}.K0PerXdlops == 1); +#else constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; - +#endif auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, FloatAB, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index e8f8caa10d..84d7b04495 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -99,7 +99,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; - static constexpr auto WaveSize = 64; // K1 should be Number<...> // Gemm0 static constexpr auto A0K1 = Number{}; @@ -110,6 +109,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave); static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave); + static constexpr auto WaveSize = BlockSize / (Gemm0MWaves * Gemm0NWaves); // Gemm1 static constexpr auto B1K1 = Number{}; static constexpr auto B1K0PerBlock = Number{}; @@ -824,16 +824,30 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple( Number{}, Number{}, Number{}); - constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; - constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; - constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; +#if defined(__gfx11__) + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed( + make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number{})); + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + Acc0DataType, + A0B0B1DataType, + decltype(acc0_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4, + 0x76543210, + 0xfedcba98, + true>{make_tuple(0, 0, 0)}; +#else constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( A1ThreadSlice_K0_M_K1, make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); - // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A1 matrix blockwise copy auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< Acc0DataType, @@ -845,7 +859,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle Sequence<1, 0, 2>, 2, n4>{tensor_operation::element_wise::PassThrough{}}; +#endif + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // B1 matrix blockwise copy auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1::selected_mfma.group_size * 2; +#else constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; - +#endif auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< BlockSize, A0B0B1DataType, @@ -987,7 +1007,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle else { static_for<0, acc0_thread_buf.Size(), 1>{}( - [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); + [&](auto i) { cde0_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); } // gemm1 { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 0f2085525f..222cb3894c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -727,16 +727,30 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(Number{}, Number{}, Number{}); - constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; - constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; - constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; +#if defined(__gfx11__) + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed( + make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number{})); + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4, + 0x76543210, + 0xfedcba98, + false>{make_tuple(0, 0, 0)}; +#else constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( A1ThreadSlice_K0_M_K1, make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); - // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A1 matrix blockwise copy auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< FloatGemmAcc, @@ -748,7 +762,9 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Sequence<1, 0, 2>, 2, n4>{tensor_operation::element_wise::PassThrough{}}; - +#endif + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // B1 matrix blockwise copy auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1::selected_mfma.group_size * 2; +#else constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; +#endif auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, @@ -873,8 +893,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle decltype(thread_slice_desc_m_n)>{}; const index_t num_gemm1_k_block_outer_loop = - b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; - constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + b_grid_desc_bk0_n_bk1.GetLength(I1) / (NPerBlock / Gemm0NWaves); + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm0NWaves / Gemm1KPerBlock; // Initialize C StaticBuffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 33b9199ea5..2d00daf7f6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -570,17 +570,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(Number{}, Number{}, Number{}); - constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; - constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; - constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + + // A1 matrix blockwise copy +#if defined(__gfx11__) + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed( + make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number{})); + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4, + 0x76543210, + 0xfedcba98, + false>{make_tuple(0, 0, 0)}; + static_assert(n4 == A1ThreadSliceK1); +#else constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( A1ThreadSlice_K0_M_K1, make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); - - // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A1 matrix blockwise copy auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< FloatGemmAcc, FloatAB, @@ -591,6 +606,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle Sequence<1, 0, 2>, 2, n4>{tensor_operation::element_wise::PassThrough{}}; +#endif + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // B1 matrix blockwise copy auto b1_blockwise_copy = @@ -640,9 +659,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // therefore we may just as well assign Gemm1KPack = group_size - +#if defined(__gfx11__) + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size * 2; +#else constexpr index_t Gemm1KPack = MfmaSelector::selected_mfma.group_size; +#endif auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, @@ -716,8 +739,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle decltype(thread_slice_desc_m_n)>{}; const index_t num_gemm1_k_block_outer_loop = - b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; - constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock * Gemm0NWaves; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock / Gemm0NWaves; // Initialize C StaticBuffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 68112489ca..638f64981f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -269,6 +269,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -325,7 +328,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -406,6 +409,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN) { @@ -459,7 +465,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index 129929b665..7cb0fd3338 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -656,6 +656,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -712,7 +715,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -793,6 +796,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN) { @@ -846,7 +852,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index e4d5b99ffe..906bfe0912 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 4c07d60b0f..5545192e3c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -818,10 +818,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves); - + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -960,9 +959,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves); + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 7947d2490a..b99113ef16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -347,7 +347,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -567,9 +569,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -598,9 +598,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle index_t BK0; index_t MBlock; index_t NBlock; - // For B pre-shuffle only - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -700,6 +697,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -735,7 +734,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -1460,10 +1459,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle void* p_shared, const Problem& problem) { + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = @@ -1842,10 +1843,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle void* p_shared_1, const Problem& problem) { + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index a7d7546b1c..a9c7556130 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -709,6 +709,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -764,7 +767,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -845,6 +848,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -896,7 +902,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) @@ -1444,12 +1450,18 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; - auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; - + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); +#if defined(__gfx11__) + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 16) / NPerXdl * KPerThread; +#else + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % WaveSize) / NPerXdl * KPerThread; +#endif auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2{}, Number{})); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; - auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); +#if defined(__gfx11__) + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 16) / NPerXdl * KPerThread; +#else + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % WaveSize) / NPerXdl * KPerThread; +#endif auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2 128) @@ -906,6 +909,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN) { @@ -959,7 +965,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index b72c4d0313..676da3e925 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -689,6 +689,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -747,7 +750,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -828,6 +831,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -883,7 +889,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index e80a3702fb..be3c6ebb35 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -685,6 +685,9 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -720,7 +723,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -801,6 +804,9 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN) { @@ -831,7 +837,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) @@ -1358,10 +1364,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + + (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl; constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 373d4eb4e4..dfcc20b3c2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -375,7 +375,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -591,9 +593,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -604,7 +604,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + << "NBlock: " << NBlock << " }" << std::endl; } index_t M; @@ -623,9 +623,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -714,6 +711,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -749,7 +748,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -1144,12 +1143,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle CElementwiseOperation c_element_op, const Block2CTileMap& block_2_ctile_map) { - ignore = b_element_op; + ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); @@ -1578,11 +1580,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Block2CTileMap& block_2_ctile_map) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index e345bc860b..d832bef2da 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -370,7 +370,9 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -547,9 +549,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -579,9 +579,6 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -676,6 +673,8 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -711,7 +710,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -1085,11 +1084,13 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); @@ -1227,10 +1228,12 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + + (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl; constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); @@ -1580,10 +1583,12 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); @@ -1727,10 +1732,11 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + + (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl; constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index bc87559c43..cb9c354701 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -807,6 +807,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -838,9 +842,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; @@ -925,6 +928,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -952,9 +958,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } else // RowMajor B { - constexpr auto WaveSize = 64; - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 7902a16fb3..3ac9845b66 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -422,7 +422,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor_packed( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } @@ -666,9 +668,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -700,9 +700,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -836,6 +833,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -867,9 +867,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; @@ -2271,10 +2270,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle void* p_shared_1, const Problem& problem) { + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 24fe81c74e..0dbdac85bf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -119,9 +119,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 // K1 should be Number<...> static constexpr auto K1 = Number{}; - static constexpr index_t WaveSize = 64; static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); + static constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 48ccb49db4..b0a606cf38 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -405,7 +405,9 @@ struct GridwiseMoeGemm __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -611,9 +613,7 @@ struct GridwiseMoeGemm AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -646,9 +646,6 @@ struct GridwiseMoeGemm index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -757,6 +754,9 @@ struct GridwiseMoeGemm __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -792,7 +792,7 @@ struct GridwiseMoeGemm constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -1164,6 +1164,8 @@ struct GridwiseMoeGemm CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1172,7 +1174,7 @@ struct GridwiseMoeGemm problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, @@ -1418,7 +1420,7 @@ struct GridwiseMoeGemm const float* p_scale_b = p_ds_grid[I1]; static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + static_assert(M4 == 4 || M4 == 8); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; @@ -1435,8 +1437,8 @@ struct GridwiseMoeGemm p_scale_b += expert_id; } - vector_type scale_token_ids; - vector_type topk_weights; + vector_type scale_token_ids; + vector_type topk_weights; static_for<0, NXdlPerWave, 1>{}([&](auto n0) { const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave @@ -1458,7 +1460,8 @@ struct GridwiseMoeGemm float scale_a = [&]() { if constexpr(PerTokenQuant) { - index_t fused_token = scale_token_ids.AsType()[m4]; + index_t fused_token = + scale_token_ids.template AsType()[m4]; const index_t token_offset = fused_token & 0xffffff; return token_offset < problem.NumTokens ? p_sorted_weights_0[IsInputGemm @@ -1489,8 +1492,8 @@ struct GridwiseMoeGemm float up = scale_a * scale_up * c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } if constexpr(is_same_v, pk_i4_t>) { @@ -1509,8 +1512,8 @@ struct GridwiseMoeGemm float up = scale_a * scale_up * c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } if constexpr(is_same_v, pk_i4_t>) { @@ -1527,8 +1530,9 @@ struct GridwiseMoeGemm scale_a * scale_b * c_thread_buf[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * - topk_weights.AsType()[m4]; + c_thread_buf_fp32(cidx) = + c_thread_buf_fp32(cidx) * + topk_weights.template AsType()[m4]; } } }); @@ -1538,7 +1542,7 @@ struct GridwiseMoeGemm } else { - vector_type topk_weights; // for gemm2 only + vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave, 1>{}([&](auto n0) { static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk @@ -1563,8 +1567,8 @@ struct GridwiseMoeGemm float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; @@ -1575,8 +1579,8 @@ struct GridwiseMoeGemm float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; @@ -1587,8 +1591,9 @@ struct GridwiseMoeGemm c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * - c_thread_buf_fp32[cidx]; + c_thread_buf_fp32(cidx) = + topk_weights.template AsType()[m4] * + c_thread_buf_fp32[cidx]; } } }); @@ -1874,6 +1879,8 @@ struct GridwiseMoeGemm CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1882,7 +1889,7 @@ struct GridwiseMoeGemm problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, @@ -2136,7 +2143,7 @@ struct GridwiseMoeGemm const float* p_scale_b = p_ds_grid[I1]; static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4); + static_assert(M4 == 4 || M4 == 8); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; @@ -2153,8 +2160,8 @@ struct GridwiseMoeGemm p_scale_b += expert_id; } - vector_type scale_token_ids; - vector_type topk_weights; + vector_type scale_token_ids; + vector_type topk_weights; static_for<0, NXdlPerWave, 1>{}([&](auto n0) { const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave @@ -2176,7 +2183,8 @@ struct GridwiseMoeGemm float scale_a = [&]() { if constexpr(PerTokenQuant) { - index_t fused_token = scale_token_ids.AsType()[m4]; + index_t fused_token = + scale_token_ids.template AsType()[m4]; const index_t token_offset = fused_token & 0xffffff; return token_offset < problem.NumTokens ? p_sorted_weights_0[IsInputGemm @@ -2207,8 +2215,8 @@ struct GridwiseMoeGemm float up = scale_a * scale_up * c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } if constexpr(is_same_v, pk_i4_t>) { @@ -2227,8 +2235,8 @@ struct GridwiseMoeGemm float up = scale_a * scale_up * c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } if constexpr(is_same_v, pk_i4_t>) { @@ -2245,8 +2253,9 @@ struct GridwiseMoeGemm scale_a * scale_b * c_thread_buf[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * - topk_weights.AsType()[m4]; + c_thread_buf_fp32(cidx) = + c_thread_buf_fp32(cidx) * + topk_weights.template AsType()[m4]; } } }); @@ -2256,7 +2265,7 @@ struct GridwiseMoeGemm } else { - vector_type topk_weights; // for gemm2 only + vector_type topk_weights; // for gemm2 only static_for<0, NXdlPerWave, 1>{}([&](auto n0) { static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk @@ -2281,8 +2290,8 @@ struct GridwiseMoeGemm float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } tensor_operation::element_wise::Silu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; @@ -2293,8 +2302,8 @@ struct GridwiseMoeGemm float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - gate = gate * topk_weights.AsType()[m4]; - up = up * topk_weights.AsType()[m4]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } tensor_operation::element_wise::Gelu{}(gate, gate); c_thread_buf_fp32(cidx) = gate * up; @@ -2305,8 +2314,9 @@ struct GridwiseMoeGemm c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * - c_thread_buf_fp32[cidx]; + c_thread_buf_fp32(cidx) = + topk_weights.template AsType()[m4] * + c_thread_buf_fp32[cidx]; } } }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 0d78957b07..a8b759da38 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -410,7 +410,9 @@ struct GridwiseMoeGemmBlockScale __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -618,9 +620,7 @@ struct GridwiseMoeGemmBlockScale AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -653,9 +653,6 @@ struct GridwiseMoeGemmBlockScale index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -771,6 +768,8 @@ struct GridwiseMoeGemmBlockScale __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM) { @@ -806,7 +805,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) @@ -1183,6 +1182,8 @@ struct GridwiseMoeGemmBlockScale CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1191,7 +1192,7 @@ struct GridwiseMoeGemmBlockScale problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, @@ -1374,10 +1375,11 @@ struct GridwiseMoeGemmBlockScale constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + + (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl; constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); @@ -1578,7 +1580,7 @@ struct GridwiseMoeGemmBlockScale static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); static_assert(M0 * M1 * M2 == MPerBlock); - static_assert(N4 == 4); + static_assert(N4 == 4 || N4 == 8); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m2 = threadIdx.x % get_warp_size() % M2; @@ -1929,6 +1931,8 @@ struct GridwiseMoeGemmBlockScale CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1937,7 +1941,7 @@ struct GridwiseMoeGemmBlockScale problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, @@ -2126,10 +2130,11 @@ struct GridwiseMoeGemmBlockScale constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - auto a_thread_offset = - get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWaves * NWaves); + auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + + (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl; constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); @@ -2321,7 +2326,7 @@ struct GridwiseMoeGemmBlockScale static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); static_assert(M0 * M1 * M2 == MPerBlock); - static_assert(N4 == 4); + static_assert(N4 == 4 || N4 == 8); const index_t m1 = get_warp_local_1d_id() / NWave; const index_t m2 = threadIdx.x % get_warp_size() % M2; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index ac3a887155..34fcf0e935 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -850,6 +850,10 @@ struct GridwiseMoeGemmMX __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -881,9 +885,8 @@ struct GridwiseMoeGemmMX // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; @@ -968,6 +971,10 @@ struct GridwiseMoeGemmMX __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -995,9 +1002,8 @@ struct GridwiseMoeGemmMX } else // RowMajor B { - constexpr auto WaveSize = 64; - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index a8417b2e02..3a7b35683d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -783,6 +783,10 @@ struct GridwiseMoeGemmMXBNS __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -813,9 +817,8 @@ struct GridwiseMoeGemmMXBNS // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; @@ -900,6 +903,10 @@ struct GridwiseMoeGemmMXBNS __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -926,9 +933,8 @@ struct GridwiseMoeGemmMXBNS } else // RowMajor B { - constexpr auto WaveSize = 64; - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 46e9a19ae6..3c4f7a24c7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -270,11 +270,11 @@ struct GridwiseMoeGemmMX_BPreshuffle return math::integer_least_multiple(N, NPerBlock); } - __host__ static auto CalculateBN0Shuffled(index_t N) + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) { return math::integer_divide_ceil(N, NLane); } - __host__ static auto CalculateBK0Shuffled(index_t K) + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { return math::integer_divide_ceil(K, KLane * KPack); } @@ -467,7 +467,9 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor_packed( make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); } @@ -700,9 +702,7 @@ struct GridwiseMoeGemmMX_BPreshuffle AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)}, - BN0Shuffled{CalculateBN0Shuffled(N_)}, - BK0Shuffled{CalculateBK0Shuffled(K_)} + NBlock{CalculateNBlock(N_)} { } @@ -738,9 +738,6 @@ struct GridwiseMoeGemmMX_BPreshuffle index_t BK0; index_t MBlock; index_t NBlock; - // FOR PRESHUFFLE ONLY - index_t BN0Shuffled; - index_t BK0Shuffled; }; // Argument @@ -869,6 +866,9 @@ struct GridwiseMoeGemmMX_BPreshuffle __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { @@ -900,9 +900,8 @@ struct GridwiseMoeGemmMX_BPreshuffle // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto WaveSize = 64; - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; @@ -1292,6 +1291,8 @@ struct GridwiseMoeGemmMX_BPreshuffle CElementwiseOperation c_element_op) { ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1300,7 +1301,7 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, @@ -1998,6 +1999,8 @@ struct GridwiseMoeGemmMX_BPreshuffle { ignore = a_element_op; ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -2006,7 +2009,7 @@ struct GridwiseMoeGemmMX_BPreshuffle problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = - MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2305997f70..5da9722a4b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1900,6 +1900,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow const DstSliceOriginIdx&, DstBuffer& dst_buf) const { + ElementwiseOperation element_op_{}; static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! Desc need to known at compile-time"); @@ -1985,7 +1986,6 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow }); }); } - ElementwiseOperation element_op_{}; }; // Specialized for gfx12 diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 0125aa086e..deea6ae9cc 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1384,25 +1384,31 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> @@ -1432,48 +1438,84 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#else return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; +#endif } template <> @@ -1852,7 +1894,7 @@ struct XdlopsGemm Sequence<8>{})); } - __device__ static constexpr index_t GetRegSizePerXdlops() + __device__ __host__ static constexpr index_t GetRegSizePerXdlops() { return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } @@ -1961,7 +2003,7 @@ struct XdlopsGemm { const auto laneId = GetLaneId(); #if defined(__gfx11__) - const auto blk_idx = GetGfx11InputBlkIdx(); + const auto blk_idx = GetGfx11InputBlkIdx(); #else const auto blk_idx = GetBlkIdx(); #endif @@ -1983,7 +2025,7 @@ struct XdlopsGemm { const auto laneId = GetLaneId(); #if defined(__gfx11__) - const auto blk_idx = GetGfx11InputBlkIdx(); + const auto blk_idx = GetGfx11InputBlkIdx(); #else const auto blk_idx = GetBlkIdx(); #endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 99538ac78c..64327d0142 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -192,7 +192,6 @@ template __host__ __device__ constexpr Y type_convert_sp(X x) { static_assert(!ck::is_reference_v && !ck::is_reference_v); - return static_cast(x); } @@ -244,6 +243,41 @@ inline __host__ __device__ constexpr half_t type_convert_sp(int x) return u.fp16; } +template <> +inline __host__ __device__ constexpr int type_convert_sp(bhalf_t x) +{ + union + { + bhalf_t fp16; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr bhalf_t type_convert_sp(int x) +{ + union + { + int int32; + bhalf_t fp16; + } u = {x}; + + return u.fp16; +} + +template <> +inline __host__ __device__ constexpr bhalf_t type_convert_sp(float x) +{ + return type_convert(x); +} + +template <> +inline __host__ __device__ constexpr half_t type_convert_sp(float x) +{ + return type_convert(x); +} // Declare a template function for fp8 conversion using SR template __host__ __device__ constexpr Y f8_convert_sr(X x); From 1acd8e041c8a7b2c95f6f7bf60ee9417eecb924a Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Thu, 4 Sep 2025 14:33:44 +0300 Subject: [PATCH 017/404] [CK Tile] gemm splitk two stage (#2697) * Fix a typo * Use std::variant to call run_gemm_example_with_layouts with the available layout variant combinations * Use a unified run_gemm_example_prec_type for basic gemm and universal gemm * Factor out run_gemm_example_prec_type * Refactor argument parsing in gemm_splitk_two_stage_reduce.cpp * Parse arguments outside of create_args * Move the gemm operators to separate structs to facilitate their reuse * Move the invokers to separate files to facilitate their reuse * Rename the invoker files for consistency with the examples that use them * Add fp32 support to the elementwise examples, and produce an error message for unsupported types * Get rid of four unused variables * Make two variables const * Add support for different input-output type combinations in elementwise examples * Test support for different input and output types in elementwise examples * Add support for different operations in the elementwise unary tests * Add support for UnaryConvert in the elementwise unary tests * Add support for bf16 in elementwise examples, excluding unsupported type combinations * Make some operator parameters const in ElementWiseKernel * Remove some unnecessary include statements * Implement a two-stage GEMM that does a type conversion in the second stage using the elementwise kernel * Clear workspace instead of output when flushing the cache in SplitKTwoStageInvoker::gemm * Fix formatting issues reported by clang * Add back CK_TILE_USE_WMMA related changes * Use the right prec type for bf16 in the universal GEMM and two stage split K examples * Add some brackets * Add some brackets * Separate the clearing of the GEMM output memory from the cache flushing in the universal GEMM example * Separate the clearing of the GEMM output memory from the cache flushing in the split K two stage example * Fix formatting * No need to call SetZero on ws_m_n_dev_buf here, as clear_gemm_output now does this as part of the kernel preprocessing * Add fp16 data type to splitk two stage example * Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the basic GEMM example --- example/ck_tile/03_gemm/CMakeLists.txt | 2 + example/ck_tile/03_gemm/gemm_basic.cpp | 223 +++----------- .../ck_tile/03_gemm/gemm_basic_invoker.hpp | 176 +++++++++++ .../ck_tile/03_gemm/gemm_splitk_two_stage.cpp | 52 ++++ .../03_gemm/gemm_splitk_two_stage_invoker.hpp | 259 +++++++++++++++++ .../03_gemm/gemm_splitk_two_stage_reduce.cpp | 56 ++-- example/ck_tile/03_gemm/gemm_utils.hpp | 6 +- .../03_gemm/gemm_weight_preshuffle.cpp | 198 +------------ .../gemm_weight_preshuffle_invoker.hpp | 204 +++++++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 65 ++--- .../03_gemm/run_gemm_example_common.hpp | 64 ++++ example/ck_tile/03_gemm/universal_gemm.cpp | 273 +----------------- .../03_gemm/universal_gemm_invoker.hpp | 197 +++++++++++++ .../21_elementwise/elementwise_common.hpp | 26 ++ .../21_elementwise/elementwise_example.cpp | 44 +-- .../elementwise_example_add_4d.cpp | 37 ++- .../elementwise_example_transpose.cpp | 27 +- .../elementwise_example_unary.cpp | 90 ++++-- .../binary_elementwise_operation.hpp | 8 + .../elementwise/kernel/elementwise_kernel.hpp | 6 +- .../unary_element_wise_operation.hpp | 14 +- 21 files changed, 1245 insertions(+), 782 deletions(-) create mode 100644 example/ck_tile/03_gemm/gemm_basic_invoker.hpp create mode 100644 example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp create mode 100644 example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp create mode 100644 example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp create mode 100644 example/ck_tile/03_gemm/run_gemm_example_common.hpp create mode 100644 example/ck_tile/03_gemm/universal_gemm_invoker.hpp create mode 100644 example/ck_tile/21_elementwise/elementwise_common.hpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 825cd6e522..d2112a67bf 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -2,6 +2,7 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) +add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -16,3 +17,4 @@ target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OP target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 99c943a7f1..d687e35f5d 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -2,185 +2,9 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_utils.hpp" - -template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) - -{ - if constexpr(Persistent) - std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; - - // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; - -#if CK_TILE_USE_WMMA - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 16; - constexpr ck_tile::index_t N_Warp_Tile = 16; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#else - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#endif - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using CodegenGemmTraits = ck_tile::TileGemmTraits; - - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; - - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } -} - #include "run_gemm_example.inc" - -template -int run_gemm_example_prec_type(std::string a_layout, - std::string b_layout, - ck_tile::ArgParser& arg_parser) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if constexpr(std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices when " - "BPrecType is ck_tile::pk_int4_t!"); - } - } - else - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } -} +#include "run_gemm_example_common.hpp" +#include "gemm_basic_invoker.hpp" int run_gemm_example(ck_tile::ArgParser& arg_parser) { @@ -188,36 +12,53 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + using GemmConfig = GemmConfigBase; + using Invoker = BasicInvoker; + if(data_type == "fp16") { - return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { - return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "i8") { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) { - return run_gemm_example_prec_type( - a_layout, b_layout, arg_parser); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else { @@ -232,7 +73,9 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + auto arg_parser = create_args(); + auto result = arg_parser.parse(argc, argv); + if(!result) return -1; diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp new file mode 100644 index 0000000000..861374e268 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "gemm_utils.hpp" + +struct BasicInvoker +{ + template + static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { + if constexpr(Persistent) + { + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; + } + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + +#if CK_TILE_USE_WMMA + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 16; + constexpr ck_tile::index_t N_Warp_Tile = 16; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#else + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; +#endif + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = ck_tile::TileGemmTraits; + + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC, + memory_operation>>; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = + std::make_unique>( + kargs.as_ptr[0], + kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + }; + + if(args.k_batch == 1) + { + return Run(MemoryOpSet{}); + } + else + { + return Run(MemoryOpAtomicAdd{}); + } + } +}; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp new file mode 100644 index 0000000000..0455e8e34d --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" +#include "run_gemm_example_common.hpp" +#include "gemm_splitk_two_stage_invoker.hpp" + +int run_gemm_example(ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + using Invoker = SplitKTwoStageInvoker; + + if(data_type == "fp16") + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type, + Invoker, + ck_tile::bf16_t>(a_layout, b_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) +{ + auto arg_parser = create_args(); + auto result = arg_parser.parse(argc, argv); + + if(!result) + return -1; + + try + { + return !run_gemm_example(arg_parser); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp new file mode 100644 index 0000000000..21867816e2 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "gemm_utils.hpp" +#include "ck_tile/ops/elementwise.hpp" + +template +struct GemmConfigTwoStage : public GemmConfigComputeV3 +{ + using WorkspaceType = ck_tile::remove_cvref_t; +}; + +struct SplitKTwoStageInvoker +{ + template + static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + + { + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using WorkspaceType = ck_tile::remove_cvref_t; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using GemmKernel = ck_tile::GemmKernel; + + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); + + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; + + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; + + for(auto d : shape) + total_elements *= d; + + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + constexpr ck_tile::index_t kBlockPerCu = 1; + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = + (total_elements + elements_per_block - 1) / elements_per_block; + + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); + + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = + std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + } + else + { + Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; + } +}; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index f42135a0b5..324dfc069a 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -608,16 +608,11 @@ template -int run_gemm_example_with_layouts_two_stage(int argc, - char* argv[], +int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); @@ -837,12 +832,13 @@ template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - auto [result, arg_parser] = create_args(argc, argv); - bool preshuffle = GemmConfig::Preshuffle; + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + bool preshuffle = GemmConfig::Preshuffle; if(preshuffle && std::is_same_v) { @@ -866,7 +862,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a CPrecType, Row, Col, - Row>(argc, argv, Row{}, Col{}, Row{}); + Row>(arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { @@ -876,7 +872,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a CPrecType, Col, Col, - Row>(argc, argv, Col{}, Col{}, Row{}); + Row>(arg_parser, Col{}, Col{}, Row{}); } else { @@ -892,7 +888,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } if(a_layout == "R" && b_layout == "C") { @@ -900,7 +896,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { @@ -908,7 +904,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { @@ -916,7 +912,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a APrecType, BPrecType, CPrecType>( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -927,12 +923,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } template